This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new ba9ff983 Add typealias for table version (#566)
ba9ff983 is described below

commit ba9ff98373e25bed7d118911942bb4315e3b949a
Author: Mehul Batra <[email protected]>
AuthorDate: Wed Apr 3 17:50:23 2024 +0530

    Add typealias for table version (#566)
    
    * typealias for table version
    
    * typealias for table version
    
    * typealias for table version
    
    * typealias for table version
    
    * typealias for table version
    
    * typealias for table version replaced in all files
---
 pyiceberg/manifest.py        | 24 ++++++++++++------------
 pyiceberg/table/__init__.py  |  5 +++--
 pyiceberg/typedef.py         |  5 +++++
 tests/utils/test_manifest.py |  8 ++++----
 4 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py
index 03dc3199..5277eed9 100644
--- a/pyiceberg/manifest.py
+++ b/pyiceberg/manifest.py
@@ -37,7 +37,7 @@ from pyiceberg.exceptions import ValidationError
 from pyiceberg.io import FileIO, InputFile, OutputFile
 from pyiceberg.partitioning import PartitionSpec
 from pyiceberg.schema import Schema
-from pyiceberg.typedef import EMPTY_DICT, Record
+from pyiceberg.typedef import EMPTY_DICT, Record, TableVersion
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -302,7 +302,7 @@ def _(partition_field_type: PrimitiveType) -> PrimitiveType:
     return partition_field_type
 
 
-def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType:
+def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType:
     data_file_partition_type = StructType(*[
         NestedField(
             field_id=field.field_id,
@@ -372,7 +372,7 @@ class DataFile(Record):
             value = FileFormat[value]
         super().__setattr__(name, value)
 
-    def __init__(self, format_version: Literal[1, 2] = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None:
+    def __init__(self, format_version: TableVersion = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None:
         super().__init__(
             *data,
             **{"struct": DATA_FILE_TYPE[format_version], **named_data},
@@ -408,7 +408,7 @@ MANIFEST_ENTRY_SCHEMAS = {
 MANIFEST_ENTRY_SCHEMAS_STRUCT = {format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()}
 
 
-def manifest_entry_schema_with_data_file(format_version: Literal[1, 2], data_file: StructType) -> Schema:
+def manifest_entry_schema_with_data_file(format_version: TableVersion, data_file: StructType) -> Schema:
     return Schema(*[
         NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field
         for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields
@@ -719,9 +719,9 @@ class ManifestWriter(ABC):
 
     @property
     @abstractmethod
-    def version(self) -> Literal[1, 2]: ...
+    def version(self) -> TableVersion: ...
 
-    def _with_partition(self, format_version: Literal[1, 2]) -> Schema:
+    def _with_partition(self, format_version: TableVersion) -> Schema:
         data_file_type = data_file_with_partition(
             format_version=format_version, partition_type=self._spec.partition_type(self._schema)
         )
@@ -807,7 +807,7 @@ class ManifestWriterV1(ManifestWriter):
         return ManifestContent.DATA
 
     @property
-    def version(self) -> Literal[1, 2]:
+    def version(self) -> TableVersion:
         return 1
 
     def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
@@ -834,7 +834,7 @@ class ManifestWriterV2(ManifestWriter):
         return ManifestContent.DATA
 
     @property
-    def version(self) -> Literal[1, 2]:
+    def version(self) -> TableVersion:
         return 2
 
     def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
@@ -847,7 +847,7 @@ class ManifestWriterV2(ManifestWriter):
 
 
 def write_manifest(
-    format_version: Literal[1, 2], spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int
+    format_version: TableVersion, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int
 ) -> ManifestWriter:
     if format_version == 1:
         return ManifestWriterV1(spec, schema, output_file, snapshot_id)
@@ -858,14 +858,14 @@ def write_manifest(
 
 
 class ManifestListWriter(ABC):
-    _format_version: Literal[1, 2]
+    _format_version: TableVersion
     _output_file: OutputFile
     _meta: Dict[str, str]
     _manifest_files: List[ManifestFile]
     _commit_snapshot_id: int
     _writer: AvroOutputFile[ManifestFile]
 
-    def __init__(self, format_version: Literal[1, 2], output_file: OutputFile, meta: Dict[str, Any]):
+    def __init__(self, format_version: TableVersion, output_file: OutputFile, meta: Dict[str, Any]):
         self._format_version = format_version
         self._output_file = output_file
         self._meta = meta
@@ -957,7 +957,7 @@ class ManifestListWriterV2(ManifestListWriter):
 
 
 def write_manifest_list(
-    format_version: Literal[1, 2],
+    format_version: TableVersion,
     output_file: OutputFile,
     snapshot_id: int,
     parent_snapshot_id: Optional[int],
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 787bdb86..5f67c05c 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -121,6 +121,7 @@ from pyiceberg.typedef import (
     KeyDefaultDict,
     Properties,
     Record,
+    TableVersion,
 )
 from pyiceberg.types import (
     IcebergType,
@@ -293,7 +294,7 @@ class Transaction:
 
         return self
 
-    def upgrade_table_version(self, format_version: Literal[1, 2]) -> Transaction:
+    def upgrade_table_version(self, format_version: TableVersion) -> Transaction:
         """Set the table to a certain version.
 
         Args:
@@ -1023,7 +1024,7 @@ class Table:
         )
 
     @property
-    def format_version(self) -> Literal[1, 2]:
+    def format_version(self) -> TableVersion:
         return self.metadata.format_version
 
     def schema(self) -> Schema:
diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py
index e57bf349..4bed386c 100644
--- a/pyiceberg/typedef.py
+++ b/pyiceberg/typedef.py
@@ -26,6 +26,7 @@ from typing import (
     Dict,
     Generic,
     List,
+    Literal,
     Optional,
     Protocol,
     Set,
@@ -37,6 +38,7 @@ from typing import (
 from uuid import UUID
 
 from pydantic import BaseModel, ConfigDict, RootModel
+from typing_extensions import TypeAlias
 
 if TYPE_CHECKING:
     from pyiceberg.types import StructType
@@ -199,3 +201,6 @@ class Record(StructProtocol):
     def record_fields(self) -> List[str]:
         """Return values of all the fields of the Record class except those specified in skip_fields."""
         return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name]
+
+
+TableVersion: TypeAlias = Literal[1, 2]
diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py
index 3e789cb8..8bb03cd8 100644
--- a/tests/utils/test_manifest.py
+++ b/tests/utils/test_manifest.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=redefined-outer-name,arguments-renamed,fixme
 from tempfile import TemporaryDirectory
-from typing import Dict, Literal
+from typing import Dict
 
 import fastavro
 import pytest
@@ -39,7 +39,7 @@ from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table.snapshots import Operation, Snapshot, Summary
 from pyiceberg.transforms import IdentityTransform
-from pyiceberg.typedef import Record
+from pyiceberg.typedef import Record, TableVersion
 from pyiceberg.types import IntegerType, NestedField
 
 
@@ -308,7 +308,7 @@ def test_read_manifest_v2(generated_manifest_file_file_v2: str) -> None:
 
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_write_manifest(
-    generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2]
+    generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
 ) -> None:
     io = load_file_io()
     snapshot = Snapshot(
@@ -478,7 +478,7 @@ def test_write_manifest(
 
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_write_manifest_list(
-    generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2]
+    generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
 ) -> None:
     io = load_file_io()
 

Reply via email to