This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 07442cc0 [Bug Fix] Allow HiveCatalog to create table with TimestamptzType (#585)
07442cc0 is described below

commit 07442cc00125ac00a91f717d9832e5479f4ff6dd
Author: Honah J <[email protected]>
AuthorDate: Mon Apr 8 00:11:34 2024 -0700

    [Bug Fix] Allow HiveCatalog to create table with TimestamptzType (#585)
---
 mkdocs/docs/configuration.md                 |  9 +++
 pyiceberg/catalog/glue.py                    |  5 +-
 pyiceberg/catalog/hive.py                    | 45 +++++++-------
 pyiceberg/table/__init__.py                  |  6 ++
 tests/catalog/test_hive.py                   | 88 +++++++++++++++++++++++++---
 tests/conftest.py                            | 74 +++++++++++++++++++++++
 tests/integration/test_writes/test_writes.py | 20 +++++++
 7 files changed, 216 insertions(+), 31 deletions(-)

diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md
index 93b198c3..1ca071f0 100644
--- a/mkdocs/docs/configuration.md
+++ b/mkdocs/docs/configuration.md
@@ -232,6 +232,15 @@ catalog:
     s3.secret-access-key: password
 ```
 
+When using Hive 2.x, make sure to set the compatibility flag:
+
+```yaml
+catalog:
+  default:
+...
+    hive.hive2-compatible: true
+```
+
 ## Glue Catalog
 
 Your AWS credentials can be passed directly through the Python API.
diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py
index e7532677..c3c2fdaf 100644
--- a/pyiceberg/catalog/glue.py
+++ b/pyiceberg/catalog/glue.py
@@ -65,6 +65,7 @@ from pyiceberg.serializers import FromInputFile
 from pyiceberg.table import (
     CommitTableRequest,
     CommitTableResponse,
+    PropertyUtil,
     Table,
     update_table_metadata,
 )
@@ -162,7 +163,7 @@ class _IcebergSchemaToGlueType(SchemaVisitor[str]):
         if isinstance(primitive, DecimalType):
             return f"decimal({primitive.precision},{primitive.scale})"
         if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
-            return str(primitive_type.root)
+            return str(primitive)
         return GLUE_PRIMITIVE_TYPES[primitive_type]
 
 
@@ -344,7 +345,7 @@ class GlueCatalog(MetastoreCatalog):
             self.glue.update_table(
                 DatabaseName=database_name,
                 TableInput=table_input,
-                SkipArchive=self.properties.get(GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
+                SkipArchive=PropertyUtil.property_as_bool(self.properties, GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
                 VersionId=version_id,
             )
         except self.glue.exceptions.EntityNotFoundException as e:
diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py
index 359bdef5..b504da9a 100644
--- a/pyiceberg/catalog/hive.py
+++ b/pyiceberg/catalog/hive.py
@@ -74,7 +74,7 @@ from pyiceberg.io import FileIO, load_file_io
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
 from pyiceberg.schema import Schema, SchemaVisitor, visit
 from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, TableProperties, update_table_metadata
+from pyiceberg.table import CommitTableRequest, CommitTableResponse, PropertyUtil, Table, TableProperties, update_table_metadata
 from pyiceberg.table.metadata import new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
@@ -95,6 +95,7 @@ from pyiceberg.types import (
     StringType,
     StructType,
     TimestampType,
+    TimestamptzType,
     TimeType,
     UUIDType,
 )
@@ -103,25 +104,13 @@ if TYPE_CHECKING:
     import pyarrow as pa
 
 
-# Replace by visitor
-hive_types = {
-    BooleanType: "boolean",
-    IntegerType: "int",
-    LongType: "bigint",
-    FloatType: "float",
-    DoubleType: "double",
-    DateType: "date",
-    TimeType: "string",
-    TimestampType: "timestamp",
-    StringType: "string",
-    UUIDType: "string",
-    BinaryType: "binary",
-    FixedType: "binary",
-}
-
 COMMENT = "comment"
 OWNER = "owner"
 
+# If set to true, HiveCatalog will operate in Hive2 compatibility mode
+HIVE2_COMPATIBLE = "hive.hive2-compatible"
+HIVE2_COMPATIBLE_DEFAULT = False
+
 
 class _HiveClient:
     """Helper class to nicely open and close the transport."""
@@ -151,10 +140,15 @@ class _HiveClient:
         self._transport.close()
 
 
-def _construct_hive_storage_descriptor(schema: Schema, location: Optional[str]) -> StorageDescriptor:
+def _construct_hive_storage_descriptor(
+    schema: Schema, location: Optional[str], hive2_compatible: bool = False
+) -> StorageDescriptor:
     ser_de_info = SerDeInfo(serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
     return StorageDescriptor(
-        [FieldSchema(field.name, visit(field.field_type, SchemaToHiveConverter()), field.doc) for field in schema.fields],
+        [
+            FieldSchema(field.name, visit(field.field_type, SchemaToHiveConverter(hive2_compatible)), field.doc)
+            for field in schema.fields
+        ],
         location,
         "org.apache.hadoop.mapred.FileInputFormat",
         "org.apache.hadoop.mapred.FileOutputFormat",
@@ -199,6 +193,7 @@ HIVE_PRIMITIVE_TYPES = {
     DateType: "date",
     TimeType: "string",
     TimestampType: "timestamp",
+    TimestamptzType: "timestamp with local time zone",
     StringType: "string",
     UUIDType: "string",
     BinaryType: "binary",
@@ -207,6 +202,11 @@ HIVE_PRIMITIVE_TYPES = {
 
 
 class SchemaToHiveConverter(SchemaVisitor[str]):
+    hive2_compatible: bool
+
+    def __init__(self, hive2_compatible: bool):
+        self.hive2_compatible = hive2_compatible
+
     def schema(self, schema: Schema, struct_result: str) -> str:
         return struct_result
 
@@ -226,6 +226,9 @@ class SchemaToHiveConverter(SchemaVisitor[str]):
     def primitive(self, primitive: PrimitiveType) -> str:
         if isinstance(primitive, DecimalType):
             return f"decimal({primitive.precision},{primitive.scale})"
+        elif self.hive2_compatible and isinstance(primitive, TimestamptzType):
+            # Hive2 doesn't support timestamp with local time zone
+            return "timestamp"
         else:
             return HIVE_PRIMITIVE_TYPES[type(primitive)]
 
@@ -314,7 +317,9 @@ class HiveCatalog(MetastoreCatalog):
             owner=properties[OWNER] if properties and OWNER in properties else getpass.getuser(),
             createTime=current_time_millis // 1000,
             lastAccessTime=current_time_millis // 1000,
-            sd=_construct_hive_storage_descriptor(schema, location),
+            sd=_construct_hive_storage_descriptor(
+                schema, location, PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT)
+            ),
             tableType=EXTERNAL_TABLE,
             parameters=_construct_parameters(metadata_location),
         )
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 2dbc32d8..ac19c1a5 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -251,6 +251,12 @@ class PropertyUtil:
         else:
             return default
 
+    @staticmethod
+    def property_as_bool(properties: Dict[str, str], property_name: str, default: bool) -> bool:
+        if value := properties.get(property_name):
+            return value.lower() == "true"
+        return default
+
 
 class Transaction:
     _table: Table
diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py
index e59b7599..a8c904d6 100644
--- a/tests/catalog/test_hive.py
+++ b/tests/catalog/test_hive.py
@@ -61,11 +61,24 @@ from pyiceberg.table.sorting import (
 from pyiceberg.transforms import BucketTransform, IdentityTransform
 from pyiceberg.typedef import UTF8
 from pyiceberg.types import (
+    BinaryType,
     BooleanType,
+    DateType,
+    DecimalType,
+    DoubleType,
+    FixedType,
+    FloatType,
     IntegerType,
+    ListType,
     LongType,
+    MapType,
     NestedField,
     StringType,
+    StructType,
+    TimestampType,
+    TimestamptzType,
+    TimeType,
+    UUIDType,
 )
 
 HIVE_CATALOG_NAME = "hive"
@@ -181,15 +194,20 @@ def test_check_number_of_namespaces(table_schema_simple: Schema) -> None:
         catalog.create_table("table", schema=table_schema_simple)
 
 
+@pytest.mark.parametrize("hive2_compatible", [True, False])
 @patch("time.time", MagicMock(return_value=12345))
-def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
+def test_create_table(
+    table_schema_with_all_types: Schema, hive_database: HiveDatabase, hive_table: HiveTable, hive2_compatible: bool
+) -> None:
     catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
+    if hive2_compatible:
+        catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL, **{"hive.hive2-compatible": "true"})
 
     catalog._client = MagicMock()
     catalog._client.__enter__().create_table.return_value = None
     catalog._client.__enter__().get_table.return_value = hive_table
     catalog._client.__enter__().get_database.return_value = hive_database
-    catalog.create_table(("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg"})
+    catalog.create_table(("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"})
 
     called_hive_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
     # This one is generated within the function itself, so we need to extract
@@ -207,9 +225,27 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
             retention=None,
             sd=StorageDescriptor(
                 cols=[
-                    FieldSchema(name="foo", type="string", comment=None),
-                    FieldSchema(name="bar", type="int", comment=None),
-                    FieldSchema(name="baz", type="boolean", comment=None),
+                    FieldSchema(name='boolean', type='boolean', comment=None),
+                    FieldSchema(name='integer', type='int', comment=None),
+                    FieldSchema(name='long', type='bigint', comment=None),
+                    FieldSchema(name='float', type='float', comment=None),
+                    FieldSchema(name='double', type='double', comment=None),
+                    FieldSchema(name='decimal', type='decimal(32,3)', comment=None),
+                    FieldSchema(name='date', type='date', comment=None),
+                    FieldSchema(name='time', type='string', comment=None),
+                    FieldSchema(name='timestamp', type='timestamp', comment=None),
+                    FieldSchema(
+                        name='timestamptz',
+                        type='timestamp' if hive2_compatible else 'timestamp with local time zone',
+                        comment=None,
+                    ),
+                    FieldSchema(name='string', type='string', comment=None),
+                    FieldSchema(name='uuid', type='string', comment=None),
+                    FieldSchema(name='fixed', type='binary', comment=None),
+                    FieldSchema(name='binary', type='binary', comment=None),
+                    FieldSchema(name='list', type='array<string>', comment=None),
+                    FieldSchema(name='map', type='map<string,int>', comment=None),
+                    FieldSchema(name='struct', type='struct<inner_string:string,inner_int:int>', comment=None),
                 ],
                 location=f"{hive_database.locationUri}/table",
                 inputFormat="org.apache.hadoop.mapred.FileInputFormat",
@@ -266,12 +302,46 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
         location=metadata.location,
         table_uuid=metadata.table_uuid,
         last_updated_ms=metadata.last_updated_ms,
-        last_column_id=3,
+        last_column_id=22,
         schemas=[
             Schema(
-                NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
-                NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
-                NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+                NestedField(field_id=1, name='boolean', field_type=BooleanType(), required=True),
+                NestedField(field_id=2, name='integer', field_type=IntegerType(), required=True),
+                NestedField(field_id=3, name='long', field_type=LongType(), required=True),
+                NestedField(field_id=4, name='float', field_type=FloatType(), required=True),
+                NestedField(field_id=5, name='double', field_type=DoubleType(), required=True),
+                NestedField(field_id=6, name='decimal', field_type=DecimalType(precision=32, scale=3), required=True),
+                NestedField(field_id=7, name='date', field_type=DateType(), required=True),
+                NestedField(field_id=8, name='time', field_type=TimeType(), required=True),
+                NestedField(field_id=9, name='timestamp', field_type=TimestampType(), required=True),
+                NestedField(field_id=10, name='timestamptz', field_type=TimestamptzType(), required=True),
+                NestedField(field_id=11, name='string', field_type=StringType(), required=True),
+                NestedField(field_id=12, name='uuid', field_type=UUIDType(), required=True),
+                NestedField(field_id=13, name='fixed', field_type=FixedType(length=12), required=True),
+                NestedField(field_id=14, name='binary', field_type=BinaryType(), required=True),
+                NestedField(
+                    field_id=15,
+                    name='list',
+                    field_type=ListType(type='list', element_id=18, element_type=StringType(), element_required=True),
+                    required=True,
+                ),
+                NestedField(
+                    field_id=16,
+                    name='map',
+                    field_type=MapType(
+                        type='map', key_id=19, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True
+                    ),
+                    required=True,
+                ),
+                NestedField(
+                    field_id=17,
+                    name='struct',
+                    field_type=StructType(
+                        NestedField(field_id=21, name='inner_string', field_type=StringType(), required=False),
+                        NestedField(field_id=22, name='inner_int', field_type=IntegerType(), required=True),
+                    ),
+                    required=True,
+                ),
                 schema_id=0,
                 identifier_field_ids=[2],
             )
diff --git a/tests/conftest.py b/tests/conftest.py
index 4a820fed..7da0a0a8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -69,7 +69,9 @@ from pyiceberg.types import (
     BinaryType,
     BooleanType,
     DateType,
+    DecimalType,
     DoubleType,
+    FixedType,
     FloatType,
     IntegerType,
     ListType,
@@ -78,6 +80,9 @@ from pyiceberg.types import (
     NestedField,
     StringType,
     StructType,
+    TimestampType,
+    TimestamptzType,
+    TimeType,
     UUIDType,
 )
 from pyiceberg.utils.datetime import datetime_to_millis
@@ -266,6 +271,54 @@ def table_schema_nested_with_struct_key_map() -> Schema:
     )
 
 
+@pytest.fixture(scope="session")
+def table_schema_with_all_types() -> Schema:
+    return schema.Schema(
+        NestedField(field_id=1, name="boolean", field_type=BooleanType(), required=True),
+        NestedField(field_id=2, name="integer", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="long", field_type=LongType(), required=True),
+        NestedField(field_id=4, name="float", field_type=FloatType(), required=True),
+        NestedField(field_id=5, name="double", field_type=DoubleType(), required=True),
+        NestedField(field_id=6, name="decimal", field_type=DecimalType(32, 3), required=True),
+        NestedField(field_id=7, name="date", field_type=DateType(), required=True),
+        NestedField(field_id=8, name="time", field_type=TimeType(), required=True),
+        NestedField(field_id=9, name="timestamp", field_type=TimestampType(), required=True),
+        NestedField(field_id=10, name="timestamptz", field_type=TimestamptzType(), required=True),
+        NestedField(field_id=11, name="string", field_type=StringType(), required=True),
+        NestedField(field_id=12, name="uuid", field_type=UUIDType(), required=True),
+        NestedField(field_id=14, name="fixed", field_type=FixedType(12), required=True),
+        NestedField(field_id=13, name="binary", field_type=BinaryType(), required=True),
+        NestedField(
+            field_id=15,
+            name="list",
+            field_type=ListType(element_id=16, element_type=StringType(), element_required=True),
+            required=True,
+        ),
+        NestedField(
+            field_id=17,
+            name="map",
+            field_type=MapType(
+                key_id=18,
+                key_type=StringType(),
+                value_id=19,
+                value_type=IntegerType(),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=20,
+            name="struct",
+            field_type=StructType(
+                NestedField(field_id=21, name="inner_string", field_type=StringType(), required=False),
+                NestedField(field_id=22, name="inner_int", field_type=IntegerType(), required=True),
+            ),
+        ),
+        schema_id=1,
+        identifier_field_ids=[2],
+    )
+
+
 @pytest.fixture(scope="session")
 def pyarrow_schema_simple_without_ids() -> "pa.Schema":
     import pyarrow as pa
@@ -1953,6 +2006,20 @@ def session_catalog() -> Catalog:
     )
 
 
+@pytest.fixture(scope="session")
+def session_catalog_hive() -> Catalog:
+    return load_catalog(
+        "local",
+        **{
+            "type": "hive",
+            "uri": "http://localhost:9083",
+            "s3.endpoint": "http://localhost:9000",
+            "s3.access-key-id": "admin",
+            "s3.secret-access-key": "password",
+        },
+    )
+
+
 @pytest.fixture(scope="session")
 def spark() -> "SparkSession":
     import importlib.metadata
@@ -1984,6 +2051,13 @@ def spark() -> "SparkSession":
         .config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000")
         .config("spark.sql.catalog.integration.s3.path-style-access", "true")
         .config("spark.sql.defaultCatalog", "integration")
+        .config("spark.sql.catalog.hive", "org.apache.iceberg.spark.SparkCatalog")
+        .config("spark.sql.catalog.hive.type", "hive")
+        .config("spark.sql.catalog.hive.uri", "http://localhost:9083")
+        .config("spark.sql.catalog.hive.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
+        .config("spark.sql.catalog.hive.warehouse", "s3://warehouse/hive/")
+        .config("spark.sql.catalog.hive.s3.endpoint", "http://localhost:9000")
+        .config("spark.sql.catalog.hive.s3.path-style-access", "true")
         .getOrCreate()
     )
 
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 62d3bb11..e1526d2a 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -33,6 +33,7 @@ from pyspark.sql import SparkSession
 from pytest_mock.plugin import MockerFixture
 
 from pyiceberg.catalog import Catalog
+from pyiceberg.catalog.hive import HiveCatalog
 from pyiceberg.catalog.sql import SqlCatalog
 from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.table import TableProperties, _dataframe_to_data_files
@@ -747,3 +748,22 @@ def test_write_within_transaction(spark: SparkSession, session_catalog: Catalog,
     tbl.transaction().set_properties({"test": "2"}).commit_transaction()
     tbl.append(arrow_table_with_null)
     assert get_metadata_entries_count(identifier) == 4
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_hive_catalog_storage_descriptor(
+    session_catalog_hive: HiveCatalog,
+    pa_schema: pa.Schema,
+    arrow_table_with_null: pa.Table,
+    spark: SparkSession,
+    format_version: int,
+) -> None:
+    tbl = _create_table(
+        session_catalog_hive, "default.test_storage_descriptor", {"format-version": format_version}, [arrow_table_with_null]
+    )
+
+    # check if pyiceberg can read the table
+    assert len(tbl.scan().to_arrow()) == 3
+    # check if spark can read the table
+    assert spark.sql("SELECT * FROM hive.default.test_storage_descriptor").count() == 3

Reply via email to