This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 07442cc0 [Bug Fix] Allow HiveCatalog to create table with
TimestamptzType (#585)
07442cc0 is described below
commit 07442cc00125ac00a91f717d9832e5479f4ff6dd
Author: Honah J <[email protected]>
AuthorDate: Mon Apr 8 00:11:34 2024 -0700
[Bug Fix] Allow HiveCatalog to create table with TimestamptzType (#585)
---
mkdocs/docs/configuration.md | 9 +++
pyiceberg/catalog/glue.py | 5 +-
pyiceberg/catalog/hive.py | 45 +++++++-------
pyiceberg/table/__init__.py | 6 ++
tests/catalog/test_hive.py | 88 +++++++++++++++++++++++++---
tests/conftest.py | 74 +++++++++++++++++++++++
tests/integration/test_writes/test_writes.py | 20 +++++++
7 files changed, 216 insertions(+), 31 deletions(-)
diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md
index 93b198c3..1ca071f0 100644
--- a/mkdocs/docs/configuration.md
+++ b/mkdocs/docs/configuration.md
@@ -232,6 +232,15 @@ catalog:
s3.secret-access-key: password
```
+When using Hive 2.x, make sure to set the compatibility flag:
+
+```yaml
+catalog:
+ default:
+...
+ hive.hive2-compatible: true
+```
+
## Glue Catalog
Your AWS credentials can be passed directly through the Python API.
diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py
index e7532677..c3c2fdaf 100644
--- a/pyiceberg/catalog/glue.py
+++ b/pyiceberg/catalog/glue.py
@@ -65,6 +65,7 @@ from pyiceberg.serializers import FromInputFile
from pyiceberg.table import (
CommitTableRequest,
CommitTableResponse,
+ PropertyUtil,
Table,
update_table_metadata,
)
@@ -162,7 +163,7 @@ class _IcebergSchemaToGlueType(SchemaVisitor[str]):
if isinstance(primitive, DecimalType):
return f"decimal({primitive.precision},{primitive.scale})"
if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
- return str(primitive_type.root)
+ return str(primitive)
return GLUE_PRIMITIVE_TYPES[primitive_type]
@@ -344,7 +345,7 @@ class GlueCatalog(MetastoreCatalog):
self.glue.update_table(
DatabaseName=database_name,
TableInput=table_input,
- SkipArchive=self.properties.get(GLUE_SKIP_ARCHIVE,
GLUE_SKIP_ARCHIVE_DEFAULT),
+ SkipArchive=PropertyUtil.property_as_bool(self.properties,
GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
VersionId=version_id,
)
except self.glue.exceptions.EntityNotFoundException as e:
diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py
index 359bdef5..b504da9a 100644
--- a/pyiceberg/catalog/hive.py
+++ b/pyiceberg/catalog/hive.py
@@ -74,7 +74,7 @@ from pyiceberg.io import FileIO, load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table,
TableProperties, update_table_metadata
+from pyiceberg.table import CommitTableRequest, CommitTableResponse,
PropertyUtil, Table, TableProperties, update_table_metadata
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
@@ -95,6 +95,7 @@ from pyiceberg.types import (
StringType,
StructType,
TimestampType,
+ TimestamptzType,
TimeType,
UUIDType,
)
@@ -103,25 +104,13 @@ if TYPE_CHECKING:
import pyarrow as pa
-# Replace by visitor
-hive_types = {
- BooleanType: "boolean",
- IntegerType: "int",
- LongType: "bigint",
- FloatType: "float",
- DoubleType: "double",
- DateType: "date",
- TimeType: "string",
- TimestampType: "timestamp",
- StringType: "string",
- UUIDType: "string",
- BinaryType: "binary",
- FixedType: "binary",
-}
-
COMMENT = "comment"
OWNER = "owner"
+# If set to true, HiveCatalog will operate in Hive2 compatibility mode
+HIVE2_COMPATIBLE = "hive.hive2-compatible"
+HIVE2_COMPATIBLE_DEFAULT = False
+
class _HiveClient:
"""Helper class to nicely open and close the transport."""
@@ -151,10 +140,15 @@ class _HiveClient:
self._transport.close()
-def _construct_hive_storage_descriptor(schema: Schema, location:
Optional[str]) -> StorageDescriptor:
+def _construct_hive_storage_descriptor(
+ schema: Schema, location: Optional[str], hive2_compatible: bool = False
+) -> StorageDescriptor:
ser_de_info =
SerDeInfo(serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
return StorageDescriptor(
- [FieldSchema(field.name, visit(field.field_type,
SchemaToHiveConverter()), field.doc) for field in schema.fields],
+ [
+ FieldSchema(field.name, visit(field.field_type,
SchemaToHiveConverter(hive2_compatible)), field.doc)
+ for field in schema.fields
+ ],
location,
"org.apache.hadoop.mapred.FileInputFormat",
"org.apache.hadoop.mapred.FileOutputFormat",
@@ -199,6 +193,7 @@ HIVE_PRIMITIVE_TYPES = {
DateType: "date",
TimeType: "string",
TimestampType: "timestamp",
+ TimestamptzType: "timestamp with local time zone",
StringType: "string",
UUIDType: "string",
BinaryType: "binary",
@@ -207,6 +202,11 @@ HIVE_PRIMITIVE_TYPES = {
class SchemaToHiveConverter(SchemaVisitor[str]):
+ hive2_compatible: bool
+
+ def __init__(self, hive2_compatible: bool):
+ self.hive2_compatible = hive2_compatible
+
def schema(self, schema: Schema, struct_result: str) -> str:
return struct_result
@@ -226,6 +226,9 @@ class SchemaToHiveConverter(SchemaVisitor[str]):
def primitive(self, primitive: PrimitiveType) -> str:
if isinstance(primitive, DecimalType):
return f"decimal({primitive.precision},{primitive.scale})"
+ elif self.hive2_compatible and isinstance(primitive, TimestamptzType):
+ # Hive2 doesn't support timestamp with local time zone
+ return "timestamp"
else:
return HIVE_PRIMITIVE_TYPES[type(primitive)]
@@ -314,7 +317,9 @@ class HiveCatalog(MetastoreCatalog):
owner=properties[OWNER] if properties and OWNER in properties else
getpass.getuser(),
createTime=current_time_millis // 1000,
lastAccessTime=current_time_millis // 1000,
- sd=_construct_hive_storage_descriptor(schema, location),
+ sd=_construct_hive_storage_descriptor(
+ schema, location,
PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE,
HIVE2_COMPATIBLE_DEFAULT)
+ ),
tableType=EXTERNAL_TABLE,
parameters=_construct_parameters(metadata_location),
)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 2dbc32d8..ac19c1a5 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -251,6 +251,12 @@ class PropertyUtil:
else:
return default
+ @staticmethod
+ def property_as_bool(properties: Dict[str, str], property_name: str,
default: bool) -> bool:
+ if value := properties.get(property_name):
+ return value.lower() == "true"
+ return default
+
class Transaction:
_table: Table
diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py
index e59b7599..a8c904d6 100644
--- a/tests/catalog/test_hive.py
+++ b/tests/catalog/test_hive.py
@@ -61,11 +61,24 @@ from pyiceberg.table.sorting import (
from pyiceberg.transforms import BucketTransform, IdentityTransform
from pyiceberg.typedef import UTF8
from pyiceberg.types import (
+ BinaryType,
BooleanType,
+ DateType,
+ DecimalType,
+ DoubleType,
+ FixedType,
+ FloatType,
IntegerType,
+ ListType,
LongType,
+ MapType,
NestedField,
StringType,
+ StructType,
+ TimestampType,
+ TimestamptzType,
+ TimeType,
+ UUIDType,
)
HIVE_CATALOG_NAME = "hive"
@@ -181,15 +194,20 @@ def test_check_number_of_namespaces(table_schema_simple:
Schema) -> None:
catalog.create_table("table", schema=table_schema_simple)
+@pytest.mark.parametrize("hive2_compatible", [True, False])
@patch("time.time", MagicMock(return_value=12345))
-def test_create_table(table_schema_simple: Schema, hive_database:
HiveDatabase, hive_table: HiveTable) -> None:
+def test_create_table(
+ table_schema_with_all_types: Schema, hive_database: HiveDatabase,
hive_table: HiveTable, hive2_compatible: bool
+) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
+ if hive2_compatible:
+ catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL,
**{"hive.hive2-compatible": "true"})
catalog._client = MagicMock()
catalog._client.__enter__().create_table.return_value = None
catalog._client.__enter__().get_table.return_value = hive_table
catalog._client.__enter__().get_database.return_value = hive_database
- catalog.create_table(("default", "table"), schema=table_schema_simple,
properties={"owner": "javaberg"})
+ catalog.create_table(("default", "table"),
schema=table_schema_with_all_types, properties={"owner": "javaberg"})
called_hive_table: HiveTable =
catalog._client.__enter__().create_table.call_args[0][0]
# This one is generated within the function itself, so we need to extract
@@ -207,9 +225,27 @@ def test_create_table(table_schema_simple: Schema,
hive_database: HiveDatabase,
retention=None,
sd=StorageDescriptor(
cols=[
- FieldSchema(name="foo", type="string", comment=None),
- FieldSchema(name="bar", type="int", comment=None),
- FieldSchema(name="baz", type="boolean", comment=None),
+ FieldSchema(name='boolean', type='boolean', comment=None),
+ FieldSchema(name='integer', type='int', comment=None),
+ FieldSchema(name='long', type='bigint', comment=None),
+ FieldSchema(name='float', type='float', comment=None),
+ FieldSchema(name='double', type='double', comment=None),
+ FieldSchema(name='decimal', type='decimal(32,3)',
comment=None),
+ FieldSchema(name='date', type='date', comment=None),
+ FieldSchema(name='time', type='string', comment=None),
+ FieldSchema(name='timestamp', type='timestamp',
comment=None),
+ FieldSchema(
+ name='timestamptz',
+ type='timestamp' if hive2_compatible else 'timestamp
with local time zone',
+ comment=None,
+ ),
+ FieldSchema(name='string', type='string', comment=None),
+ FieldSchema(name='uuid', type='string', comment=None),
+ FieldSchema(name='fixed', type='binary', comment=None),
+ FieldSchema(name='binary', type='binary', comment=None),
+ FieldSchema(name='list', type='array<string>',
comment=None),
+ FieldSchema(name='map', type='map<string,int>',
comment=None),
+ FieldSchema(name='struct',
type='struct<inner_string:string,inner_int:int>', comment=None),
],
location=f"{hive_database.locationUri}/table",
inputFormat="org.apache.hadoop.mapred.FileInputFormat",
@@ -266,12 +302,46 @@ def test_create_table(table_schema_simple: Schema,
hive_database: HiveDatabase,
location=metadata.location,
table_uuid=metadata.table_uuid,
last_updated_ms=metadata.last_updated_ms,
- last_column_id=3,
+ last_column_id=22,
schemas=[
Schema(
- NestedField(field_id=1, name="foo", field_type=StringType(),
required=False),
- NestedField(field_id=2, name="bar", field_type=IntegerType(),
required=True),
- NestedField(field_id=3, name="baz", field_type=BooleanType(),
required=False),
+ NestedField(field_id=1, name='boolean',
field_type=BooleanType(), required=True),
+ NestedField(field_id=2, name='integer',
field_type=IntegerType(), required=True),
+ NestedField(field_id=3, name='long', field_type=LongType(),
required=True),
+ NestedField(field_id=4, name='float', field_type=FloatType(),
required=True),
+ NestedField(field_id=5, name='double',
field_type=DoubleType(), required=True),
+ NestedField(field_id=6, name='decimal',
field_type=DecimalType(precision=32, scale=3), required=True),
+ NestedField(field_id=7, name='date', field_type=DateType(),
required=True),
+ NestedField(field_id=8, name='time', field_type=TimeType(),
required=True),
+ NestedField(field_id=9, name='timestamp',
field_type=TimestampType(), required=True),
+ NestedField(field_id=10, name='timestamptz',
field_type=TimestamptzType(), required=True),
+ NestedField(field_id=11, name='string',
field_type=StringType(), required=True),
+ NestedField(field_id=12, name='uuid', field_type=UUIDType(),
required=True),
+ NestedField(field_id=13, name='fixed',
field_type=FixedType(length=12), required=True),
+ NestedField(field_id=14, name='binary',
field_type=BinaryType(), required=True),
+ NestedField(
+ field_id=15,
+ name='list',
+ field_type=ListType(type='list', element_id=18,
element_type=StringType(), element_required=True),
+ required=True,
+ ),
+ NestedField(
+ field_id=16,
+ name='map',
+ field_type=MapType(
+ type='map', key_id=19, key_type=StringType(),
value_id=20, value_type=IntegerType(), value_required=True
+ ),
+ required=True,
+ ),
+ NestedField(
+ field_id=17,
+ name='struct',
+ field_type=StructType(
+ NestedField(field_id=21, name='inner_string',
field_type=StringType(), required=False),
+ NestedField(field_id=22, name='inner_int',
field_type=IntegerType(), required=True),
+ ),
+ required=True,
+ ),
schema_id=0,
identifier_field_ids=[2],
)
diff --git a/tests/conftest.py b/tests/conftest.py
index 4a820fed..7da0a0a8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -69,7 +69,9 @@ from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
+ DecimalType,
DoubleType,
+ FixedType,
FloatType,
IntegerType,
ListType,
@@ -78,6 +80,9 @@ from pyiceberg.types import (
NestedField,
StringType,
StructType,
+ TimestampType,
+ TimestamptzType,
+ TimeType,
UUIDType,
)
from pyiceberg.utils.datetime import datetime_to_millis
@@ -266,6 +271,54 @@ def table_schema_nested_with_struct_key_map() -> Schema:
)
+@pytest.fixture(scope="session")
+def table_schema_with_all_types() -> Schema:
+ return schema.Schema(
+ NestedField(field_id=1, name="boolean", field_type=BooleanType(),
required=True),
+ NestedField(field_id=2, name="integer", field_type=IntegerType(),
required=True),
+ NestedField(field_id=3, name="long", field_type=LongType(),
required=True),
+ NestedField(field_id=4, name="float", field_type=FloatType(),
required=True),
+ NestedField(field_id=5, name="double", field_type=DoubleType(),
required=True),
+ NestedField(field_id=6, name="decimal", field_type=DecimalType(32, 3),
required=True),
+ NestedField(field_id=7, name="date", field_type=DateType(),
required=True),
+ NestedField(field_id=8, name="time", field_type=TimeType(),
required=True),
+ NestedField(field_id=9, name="timestamp", field_type=TimestampType(),
required=True),
+ NestedField(field_id=10, name="timestamptz",
field_type=TimestamptzType(), required=True),
+ NestedField(field_id=11, name="string", field_type=StringType(),
required=True),
+ NestedField(field_id=12, name="uuid", field_type=UUIDType(),
required=True),
+ NestedField(field_id=14, name="fixed", field_type=FixedType(12),
required=True),
+ NestedField(field_id=13, name="binary", field_type=BinaryType(),
required=True),
+ NestedField(
+ field_id=15,
+ name="list",
+ field_type=ListType(element_id=16, element_type=StringType(),
element_required=True),
+ required=True,
+ ),
+ NestedField(
+ field_id=17,
+ name="map",
+ field_type=MapType(
+ key_id=18,
+ key_type=StringType(),
+ value_id=19,
+ value_type=IntegerType(),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ NestedField(
+ field_id=20,
+ name="struct",
+ field_type=StructType(
+ NestedField(field_id=21, name="inner_string",
field_type=StringType(), required=False),
+ NestedField(field_id=22, name="inner_int",
field_type=IntegerType(), required=True),
+ ),
+ ),
+ schema_id=1,
+ identifier_field_ids=[2],
+ )
+
+
@pytest.fixture(scope="session")
def pyarrow_schema_simple_without_ids() -> "pa.Schema":
import pyarrow as pa
@@ -1953,6 +2006,20 @@ def session_catalog() -> Catalog:
)
+@pytest.fixture(scope="session")
+def session_catalog_hive() -> Catalog:
+ return load_catalog(
+ "local",
+ **{
+ "type": "hive",
+ "uri": "http://localhost:9083",
+ "s3.endpoint": "http://localhost:9000",
+ "s3.access-key-id": "admin",
+ "s3.secret-access-key": "password",
+ },
+ )
+
+
@pytest.fixture(scope="session")
def spark() -> "SparkSession":
import importlib.metadata
@@ -1984,6 +2051,13 @@ def spark() -> "SparkSession":
.config("spark.sql.catalog.integration.s3.endpoint",
"http://localhost:9000")
.config("spark.sql.catalog.integration.s3.path-style-access", "true")
.config("spark.sql.defaultCatalog", "integration")
+ .config("spark.sql.catalog.hive",
"org.apache.iceberg.spark.SparkCatalog")
+ .config("spark.sql.catalog.hive.type", "hive")
+ .config("spark.sql.catalog.hive.uri", "http://localhost:9083")
+ .config("spark.sql.catalog.hive.io-impl",
"org.apache.iceberg.aws.s3.S3FileIO")
+ .config("spark.sql.catalog.hive.warehouse", "s3://warehouse/hive/")
+ .config("spark.sql.catalog.hive.s3.endpoint", "http://localhost:9000")
+ .config("spark.sql.catalog.hive.s3.path-style-access", "true")
.getOrCreate()
)
diff --git a/tests/integration/test_writes/test_writes.py
b/tests/integration/test_writes/test_writes.py
index 62d3bb11..e1526d2a 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -33,6 +33,7 @@ from pyspark.sql import SparkSession
from pytest_mock.plugin import MockerFixture
from pyiceberg.catalog import Catalog
+from pyiceberg.catalog.hive import HiveCatalog
from pyiceberg.catalog.sql import SqlCatalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.table import TableProperties, _dataframe_to_data_files
@@ -747,3 +748,22 @@ def test_write_within_transaction(spark: SparkSession,
session_catalog: Catalog,
tbl.transaction().set_properties({"test": "2"}).commit_transaction()
tbl.append(arrow_table_with_null)
assert get_metadata_entries_count(identifier) == 4
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_hive_catalog_storage_descriptor(
+ session_catalog_hive: HiveCatalog,
+ pa_schema: pa.Schema,
+ arrow_table_with_null: pa.Table,
+ spark: SparkSession,
+ format_version: int,
+) -> None:
+ tbl = _create_table(
+ session_catalog_hive, "default.test_storage_descriptor",
{"format-version": format_version}, [arrow_table_with_null]
+ )
+
+ # check if pyiceberg can read the table
+ assert len(tbl.scan().to_arrow()) == 3
+ # check if spark can read the table
+ assert spark.sql("SELECT * FROM
hive.default.test_storage_descriptor").count() == 3