This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new d872245f Remove trailing slash from table location when creating a table (#702)
d872245f is described below
commit d872245f143b919b25fda90a7b7c7fb0729a402a
Author: Felix Scherz <[email protected]>
AuthorDate: Mon May 6 19:32:09 2024 +0200
Remove trailing slash from table location when creating a table (#702)
---
pyiceberg/catalog/__init__.py | 2 +-
pyiceberg/catalog/rest.py | 2 +
tests/catalog/test_base.py | 14 ++++
tests/catalog/test_dynamodb.py | 15 ++++
tests/catalog/test_glue.py | 16 ++++
tests/catalog/test_hive.py | 175 +++++++++++++++++++++++++++++++++++++++++
tests/catalog/test_rest.py | 25 ++++++
tests/catalog/test_sql.py | 22 ++++++
8 files changed, 270 insertions(+), 1 deletion(-)
diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py
index 18d803fe..5bb9ec27 100644
--- a/pyiceberg/catalog/__init__.py
+++ b/pyiceberg/catalog/__init__.py
@@ -779,7 +779,7 @@ class MetastoreCatalog(Catalog, ABC):
def _resolve_table_location(self, location: Optional[str], database_name:
str, table_name: str) -> str:
if not location:
return self._get_default_warehouse_location(database_name,
table_name)
- return location
+ return location.rstrip("/")
def _get_default_warehouse_location(self, database_name: str, table_name:
str) -> str:
database_properties = self.load_namespace_properties(database_name)
diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py
index 53e3f6a1..565d8091 100644
--- a/pyiceberg/catalog/rest.py
+++ b/pyiceberg/catalog/rest.py
@@ -519,6 +519,8 @@ class RestCatalog(Catalog):
fresh_sort_order = assign_fresh_sort_order_ids(sort_order,
iceberg_schema, fresh_schema)
namespace_and_table = self._split_identifier_for_path(identifier)
+ if location:
+ location = location.rstrip("/")
request = CreateTableRequest(
name=namespace_and_table["table"],
location=location,
diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py
index 7d5e0a97..06e9a8a3 100644
--- a/tests/catalog/test_base.py
+++ b/tests/catalog/test_base.py
@@ -105,6 +105,7 @@ class InMemoryCatalog(MetastoreCatalog):
if not location:
location = f'{self._warehouse_location}/{"/".join(identifier)}'
+ location = location.rstrip("/")
metadata_location = self._get_metadata_location(location=location)
metadata = new_table_metadata(
@@ -353,6 +354,19 @@ def test_create_table_location_override(catalog:
InMemoryCatalog) -> None:
assert table.location() == new_location
+def test_create_table_removes_trailing_slash_from_location(catalog:
InMemoryCatalog) -> None:
+ new_location = f"{catalog._warehouse_location}/new_location"
+ table = catalog.create_table(
+ identifier=TEST_TABLE_IDENTIFIER,
+ schema=TEST_TABLE_SCHEMA,
+ location=f"{new_location}/",
+ partition_spec=TEST_TABLE_PARTITION_SPEC,
+ properties=TEST_TABLE_PROPERTIES,
+ )
+ assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table
+ assert table.location() == new_location
+
+
@pytest.mark.parametrize(
"schema,expected",
[
diff --git a/tests/catalog/test_dynamodb.py b/tests/catalog/test_dynamodb.py
index 1c647cf8..f4b16d34 100644
--- a/tests/catalog/test_dynamodb.py
+++ b/tests/catalog/test_dynamodb.py
@@ -117,6 +117,21 @@ def test_create_table_with_given_location(
assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+@mock_aws
+def test_create_table_removes_trailing_slash_in_location(
+ _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested:
Schema, database_name: str, table_name: str
+) -> None:
+ catalog_name = "test_ddb_catalog"
+ identifier = (database_name, table_name)
+ test_catalog = DynamoDbCatalog(catalog_name, **{"s3.endpoint":
moto_endpoint_url})
+ test_catalog.create_namespace(namespace=database_name)
+ location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+ table = test_catalog.create_table(identifier=identifier,
schema=table_schema_nested, location=f"{location}/")
+ assert table.identifier == (catalog_name,) + identifier
+ assert table.location() == location
+ assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+
+
@mock_aws
def test_create_table_with_no_location(
_bucket_initialize: None, table_schema_nested: Schema, database_name: str,
table_name: str
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index 5999b192..5b67b92c 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -137,6 +137,22 @@ def test_create_table_with_given_location(
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+@mock_aws
+def test_create_table_removes_trailing_slash_in_location(
+ _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested:
Schema, database_name: str, table_name: str
+) -> None:
+ catalog_name = "glue"
+ identifier = (database_name, table_name)
+ test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint":
moto_endpoint_url})
+ test_catalog.create_namespace(namespace=database_name)
+ location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+ table = test_catalog.create_table(identifier=identifier,
schema=table_schema_nested, location=f"{location}/")
+ assert table.identifier == (catalog_name,) + identifier
+ assert table.location() == location
+ assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+ assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+
+
@mock_aws
def test_create_table_with_pyarrow_schema(
_bucket_initialize: None,
diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py
index 70927ea1..af3a3801 100644
--- a/tests/catalog/test_hive.py
+++ b/tests/catalog/test_hive.py
@@ -365,6 +365,181 @@ def test_create_table(
assert metadata.model_dump() == expected.model_dump()
[email protected]("hive2_compatible", [True, False])
+@patch("time.time", MagicMock(return_value=12345))
+def test_create_table_with_given_location_removes_trailing_slash(
+ table_schema_with_all_types: Schema, hive_database: HiveDatabase,
hive_table: HiveTable, hive2_compatible: bool
+) -> None:
+ catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
+ if hive2_compatible:
+ catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL,
**{"hive.hive2-compatible": "true"})
+
+ location = f"{hive_database.locationUri}/table-given-location"
+
+ catalog._client = MagicMock()
+ catalog._client.__enter__().create_table.return_value = None
+ catalog._client.__enter__().get_table.return_value = hive_table
+ catalog._client.__enter__().get_database.return_value = hive_database
+ catalog.create_table(
+ ("default", "table"), schema=table_schema_with_all_types,
properties={"owner": "javaberg"}, location=f"{location}/"
+ )
+
+ called_hive_table: HiveTable =
catalog._client.__enter__().create_table.call_args[0][0]
+ # This one is generated within the function itself, so we need to extract
+ # it to construct the assert_called_with
+ metadata_location: str = called_hive_table.parameters["metadata_location"]
+ assert metadata_location.endswith(".metadata.json")
+ assert "/database/table-given-location/metadata/" in metadata_location
+ catalog._client.__enter__().create_table.assert_called_with(
+ HiveTable(
+ tableName="table",
+ dbName="default",
+ owner="javaberg",
+ createTime=12345,
+ lastAccessTime=12345,
+ retention=None,
+ sd=StorageDescriptor(
+ cols=[
+ FieldSchema(name='boolean', type='boolean', comment=None),
+ FieldSchema(name='integer', type='int', comment=None),
+ FieldSchema(name='long', type='bigint', comment=None),
+ FieldSchema(name='float', type='float', comment=None),
+ FieldSchema(name='double', type='double', comment=None),
+ FieldSchema(name='decimal', type='decimal(32,3)',
comment=None),
+ FieldSchema(name='date', type='date', comment=None),
+ FieldSchema(name='time', type='string', comment=None),
+ FieldSchema(name='timestamp', type='timestamp',
comment=None),
+ FieldSchema(
+ name='timestamptz',
+ type='timestamp' if hive2_compatible else 'timestamp
with local time zone',
+ comment=None,
+ ),
+ FieldSchema(name='string', type='string', comment=None),
+ FieldSchema(name='uuid', type='string', comment=None),
+ FieldSchema(name='fixed', type='binary', comment=None),
+ FieldSchema(name='binary', type='binary', comment=None),
+ FieldSchema(name='list', type='array<string>',
comment=None),
+ FieldSchema(name='map', type='map<string,int>',
comment=None),
+ FieldSchema(name='struct',
type='struct<inner_string:string,inner_int:int>', comment=None),
+ ],
+ location=f"{hive_database.locationUri}/table-given-location",
+ inputFormat="org.apache.hadoop.mapred.FileInputFormat",
+ outputFormat="org.apache.hadoop.mapred.FileOutputFormat",
+ compressed=None,
+ numBuckets=None,
+ serdeInfo=SerDeInfo(
+ name=None,
+
serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
+ parameters=None,
+ description=None,
+ serializerClass=None,
+ deserializerClass=None,
+ serdeType=None,
+ ),
+ bucketCols=None,
+ sortCols=None,
+ parameters=None,
+ skewedInfo=None,
+ storedAsSubDirectories=None,
+ ),
+ partitionKeys=None,
+ parameters={"EXTERNAL": "TRUE", "table_type": "ICEBERG",
"metadata_location": metadata_location},
+ viewOriginalText=None,
+ viewExpandedText=None,
+ tableType="EXTERNAL_TABLE",
+ privileges=None,
+ temporary=False,
+ rewriteEnabled=None,
+ creationMetadata=None,
+ catName=None,
+ ownerType=1,
+ writeId=-1,
+ isStatsCompliant=None,
+ colStats=None,
+ accessType=None,
+ requiredReadCapabilities=None,
+ requiredWriteCapabilities=None,
+ id=None,
+ fileMetadata=None,
+ dictionary=None,
+ txnId=None,
+ )
+ )
+
+ with open(metadata_location, encoding=UTF8) as f:
+ payload = f.read()
+
+ metadata = TableMetadataUtil.parse_raw(payload)
+
+ assert "database/table-given-location" in metadata.location
+
+ expected = TableMetadataV2(
+ location=metadata.location,
+ table_uuid=metadata.table_uuid,
+ last_updated_ms=metadata.last_updated_ms,
+ last_column_id=22,
+ schemas=[
+ Schema(
+ NestedField(field_id=1, name='boolean',
field_type=BooleanType(), required=True),
+ NestedField(field_id=2, name='integer',
field_type=IntegerType(), required=True),
+ NestedField(field_id=3, name='long', field_type=LongType(),
required=True),
+ NestedField(field_id=4, name='float', field_type=FloatType(),
required=True),
+ NestedField(field_id=5, name='double',
field_type=DoubleType(), required=True),
+ NestedField(field_id=6, name='decimal',
field_type=DecimalType(precision=32, scale=3), required=True),
+ NestedField(field_id=7, name='date', field_type=DateType(),
required=True),
+ NestedField(field_id=8, name='time', field_type=TimeType(),
required=True),
+ NestedField(field_id=9, name='timestamp',
field_type=TimestampType(), required=True),
+ NestedField(field_id=10, name='timestamptz',
field_type=TimestamptzType(), required=True),
+ NestedField(field_id=11, name='string',
field_type=StringType(), required=True),
+ NestedField(field_id=12, name='uuid', field_type=UUIDType(),
required=True),
+ NestedField(field_id=13, name='fixed',
field_type=FixedType(length=12), required=True),
+ NestedField(field_id=14, name='binary',
field_type=BinaryType(), required=True),
+ NestedField(
+ field_id=15,
+ name='list',
+ field_type=ListType(type='list', element_id=18,
element_type=StringType(), element_required=True),
+ required=True,
+ ),
+ NestedField(
+ field_id=16,
+ name='map',
+ field_type=MapType(
+ type='map', key_id=19, key_type=StringType(),
value_id=20, value_type=IntegerType(), value_required=True
+ ),
+ required=True,
+ ),
+ NestedField(
+ field_id=17,
+ name='struct',
+ field_type=StructType(
+ NestedField(field_id=21, name='inner_string',
field_type=StringType(), required=False),
+ NestedField(field_id=22, name='inner_int',
field_type=IntegerType(), required=True),
+ ),
+ required=False,
+ ),
+ schema_id=0,
+ identifier_field_ids=[2],
+ )
+ ],
+ current_schema_id=0,
+ last_partition_id=999,
+ properties={"owner": "javaberg", 'write.parquet.compression-codec':
'zstd'},
+ partition_specs=[PartitionSpec()],
+ default_spec_id=0,
+ current_snapshot_id=None,
+ snapshots=[],
+ snapshot_log=[],
+ metadata_log=[],
+ sort_orders=[SortOrder(order_id=0)],
+ default_sort_order_id=0,
+ refs={},
+ format_version=2,
+ last_sequence_number=0,
+ )
+
+ assert metadata.model_dump() == expected.model_dump()
+
+
@patch("time.time", MagicMock(return_value=12345))
def test_create_v1_table(table_schema_simple: Schema, hive_database:
HiveDatabase, hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py
index 15ddb01b..b8410d68 100644
--- a/tests/catalog/test_rest.py
+++ b/tests/catalog/test_rest.py
@@ -732,6 +732,31 @@ def test_create_table_200(
assert actual == expected
+def test_create_table_with_given_location_removes_trailing_slash_200(
+ rest_mock: Mocker, table_schema_simple: Schema,
example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any]
+) -> None:
+ rest_mock.post(
+ f"{TEST_URI}v1/namespaces/fokko/tables",
+ json=example_table_metadata_no_snapshot_v1_rest_json,
+ status_code=200,
+ request_headers=TEST_HEADERS,
+ )
+ catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN)
+ location = "s3://warehouse/database/table-custom-location"
+ catalog.create_table(
+ identifier=("fokko", "fokko2"),
+ schema=table_schema_simple,
+ location=f"{location}/",
+ partition_spec=PartitionSpec(
+ PartitionField(source_id=1, field_id=1000,
transform=TruncateTransform(width=3), name="id"), spec_id=1
+ ),
+ sort_order=SortOrder(SortField(source_id=2,
transform=IdentityTransform())),
+ properties={"owner": "fokko"},
+ )
+ assert rest_mock.last_request
+ assert rest_mock.last_request.json()["location"] == location
+
+
def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) ->
None:
rest_mock.post(
f"{TEST_URI}v1/namespaces/fokko/tables",
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 40a1566e..97965268 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -264,6 +264,28 @@ def test_create_table_with_default_warehouse_location(
catalog.drop_table(random_identifier)
[email protected](
+ 'catalog',
+ [
+ lazy_fixture('catalog_memory'),
+ lazy_fixture('catalog_sqlite'),
+ ],
+)
+def test_create_table_with_given_location_removes_trailing_slash(
+ warehouse: Path, catalog: SqlCatalog, table_schema_nested: Schema,
random_identifier: Identifier
+) -> None:
+ database_name, table_name = random_identifier
+ location = f"file://{warehouse}/{database_name}.db/{table_name}-given"
+ catalog.create_namespace(database_name)
+ catalog.create_table(random_identifier, table_schema_nested,
location=f"{location}/")
+ table = catalog.load_table(random_identifier)
+ assert table.identifier == (catalog.name,) + random_identifier
+ assert table.metadata_location.startswith(f"file://{warehouse}")
+ assert os.path.exists(table.metadata_location[len("file://") :])
+ assert table.location() == location
+ catalog.drop_table(random_identifier)
+
+
@pytest.mark.parametrize(
'catalog',
[