This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new cf00f6a06b Python: Refactor Avro read path to use a partner visitor
(#6506)
cf00f6a06b is described below
commit cf00f6a06b256e9c4defe226b6a37aa83c40f561
Author: Ryan Blue <[email protected]>
AuthorDate: Mon Jan 2 10:18:21 2023 -0800
Python: Refactor Avro read path to use a partner visitor (#6506)
---
python/pyiceberg/avro/file.py | 32 +++--
python/pyiceberg/avro/reader.py | 115 ++++------------
python/pyiceberg/avro/resolver.py | 271 ++++++++++++++++++++++++++-----------
python/pyiceberg/io/pyarrow.py | 6 +-
python/pyiceberg/schema.py | 102 +++++++++++++-
python/pyiceberg/typedef.py | 5 +
python/pyiceberg/types.py | 6 +
python/tests/avro/test_reader.py | 34 ++---
python/tests/avro/test_resolver.py | 4 +-
9 files changed, 371 insertions(+), 204 deletions(-)
diff --git a/python/pyiceberg/avro/file.py b/python/pyiceberg/avro/file.py
index cc725e9db7..50c85102ad 100644
--- a/python/pyiceberg/avro/file.py
+++ b/python/pyiceberg/avro/file.py
@@ -24,16 +24,21 @@ import json
from dataclasses import dataclass
from io import SEEK_SET, BufferedReader
from types import TracebackType
-from typing import Optional, Type
+from typing import (
+ Callable,
+ Dict,
+ Optional,
+ Type,
+)
from pyiceberg.avro.codecs import KNOWN_CODECS, Codec
from pyiceberg.avro.decoder import BinaryDecoder
-from pyiceberg.avro.reader import ConstructReader, Reader
-from pyiceberg.avro.resolver import resolve
+from pyiceberg.avro.reader import Reader
+from pyiceberg.avro.resolver import construct_reader, resolve
from pyiceberg.io import InputFile, InputStream
from pyiceberg.io.memory import MemoryInputStream
-from pyiceberg.schema import Schema, visit
-from pyiceberg.typedef import Record
+from pyiceberg.schema import Schema
+from pyiceberg.typedef import EMPTY_DICT, Record, StructProtocol
from pyiceberg.types import (
FixedType,
MapType,
@@ -112,6 +117,7 @@ class Block:
class AvroFile:
input_file: InputFile
read_schema: Optional[Schema]
+ read_types: Dict[int, Callable[[Schema], StructProtocol]]
input_stream: InputStream
header: AvroFileHeader
schema: Schema
@@ -120,9 +126,15 @@ class AvroFile:
decoder: BinaryDecoder
block: Optional[Block] = None
- def __init__(self, input_file: InputFile, read_schema: Optional[Schema] =
None) -> None:
+ def __init__(
+ self,
+ input_file: InputFile,
+ read_schema: Optional[Schema] = None,
+ read_types: Dict[int, Callable[[Schema], StructProtocol]] = EMPTY_DICT,
+ ) -> None:
self.input_file = input_file
self.read_schema = read_schema
+ self.read_types = read_types
def __enter__(self) -> AvroFile:
"""
@@ -137,9 +149,9 @@ class AvroFile:
self.header = self._read_header()
self.schema = self.header.get_schema()
if not self.read_schema:
- self.reader = visit(self.schema, ConstructReader())
- else:
- self.reader = resolve(self.schema, self.read_schema)
+ self.read_schema = self.schema
+
+ self.reader = resolve(self.schema, self.read_schema, self.read_types)
return self
@@ -184,6 +196,6 @@ class AvroFile:
def _read_header(self) -> AvroFileHeader:
self.input_stream.seek(0, SEEK_SET)
- reader = visit(META_SCHEMA, ConstructReader())
+ reader = construct_reader(META_SCHEMA)
_header = reader.read(self.decoder)
return AvroFileHeader(magic=_header.get(0), meta=_header.get(1),
sync=_header.get(2))
diff --git a/python/pyiceberg/avro/reader.py b/python/pyiceberg/avro/reader.py
index d1b114cc5b..f7a0194a9a 100644
--- a/python/pyiceberg/avro/reader.py
+++ b/python/pyiceberg/avro/reader.py
@@ -37,33 +37,11 @@ from typing import (
List,
Optional,
Tuple,
- Union,
)
from uuid import UUID
from pyiceberg.avro.decoder import BinaryDecoder
-from pyiceberg.schema import Schema, SchemaVisitorPerPrimitiveType
from pyiceberg.typedef import Record, StructProtocol
-from pyiceberg.types import (
- BinaryType,
- BooleanType,
- DateType,
- DecimalType,
- DoubleType,
- FixedType,
- FloatType,
- IntegerType,
- ListType,
- LongType,
- MapType,
- NestedField,
- StringType,
- StructType,
- TimestampType,
- TimestamptzType,
- TimeType,
- UUIDType,
-)
from pyiceberg.utils.singleton import Singleton
@@ -260,25 +238,43 @@ class OptionReader(Reader):
return self.option.skip(decoder)
-@dataclass(frozen=True)
-class StructReader(Reader):
- fields: Tuple[Tuple[Optional[int], Reader], ...] = dataclassfield()
+class StructProtocolReader(Reader):
+ create_struct: Callable[[], StructProtocol]
+ fields: Tuple[Tuple[Optional[int], Reader], ...]
+
+ def __init__(self, fields: Tuple[Tuple[Optional[int], Reader], ...],
create_struct: Callable[[], StructProtocol]):
+ self.create_struct = create_struct
+ self.fields = fields
+
+ def create_or_reuse(self, reuse: Optional[StructProtocol]) ->
StructProtocol:
+ if reuse:
+ return reuse
+ else:
+ return self.create_struct()
+
+ def read(self, decoder: BinaryDecoder) -> Any:
+ struct = self.create_or_reuse(None)
- def read(self, decoder: BinaryDecoder) -> Record:
- result: List[Union[Any, StructProtocol]] = [None] * len(self.fields)
for (pos, field) in self.fields:
if pos is not None:
- result[pos] = field.read(decoder)
+ struct.set(pos, field.read(decoder)) # later: pass reuse in
here
else:
field.skip(decoder)
- return Record(*result)
+ return struct
def skip(self, decoder: BinaryDecoder) -> None:
for _, field in self.fields:
field.skip(decoder)
+class StructReader(StructProtocolReader):
+ fields: Tuple[Tuple[Optional[int], Reader], ...]
+
+ def __init__(self, fields: Tuple[Tuple[Optional[int], Reader], ...]):
+ super().__init__(fields, lambda: Record.of(len(fields)))
+
+
@dataclass(frozen=True)
class ListReader(Reader):
element: Reader
@@ -325,64 +321,3 @@ class MapReader(Reader):
self.value.skip(decoder)
_skip_map_array(decoder, skip)
-
-
-class ConstructReader(SchemaVisitorPerPrimitiveType[Reader]):
- def schema(self, schema: Schema, struct_result: Reader) -> Reader:
- return struct_result
-
- def struct(self, struct: StructType, field_results: List[Reader]) ->
Reader:
- return StructReader(tuple(enumerate(field_results)))
-
- def field(self, field: NestedField, field_result: Reader) -> Reader:
- return field_result if field.required else OptionReader(field_result)
-
- def list(self, list_type: ListType, element_result: Reader) -> Reader:
- element_reader = element_result if list_type.element_required else
OptionReader(element_result)
- return ListReader(element_reader)
-
- def map(self, map_type: MapType, key_result: Reader, value_result: Reader)
-> Reader:
- value_reader = value_result if map_type.value_required else
OptionReader(value_result)
- return MapReader(key_result, value_reader)
-
- def visit_fixed(self, fixed_type: FixedType) -> Reader:
- return FixedReader(len(fixed_type))
-
- def visit_decimal(self, decimal_type: DecimalType) -> Reader:
- return DecimalReader(decimal_type.precision, decimal_type.scale)
-
- def visit_boolean(self, boolean_type: BooleanType) -> Reader:
- return BooleanReader()
-
- def visit_integer(self, integer_type: IntegerType) -> Reader:
- return IntegerReader()
-
- def visit_long(self, long_type: LongType) -> Reader:
- return IntegerReader()
-
- def visit_float(self, float_type: FloatType) -> Reader:
- return FloatReader()
-
- def visit_double(self, double_type: DoubleType) -> Reader:
- return DoubleReader()
-
- def visit_date(self, date_type: DateType) -> Reader:
- return DateReader()
-
- def visit_time(self, time_type: TimeType) -> Reader:
- return TimeReader()
-
- def visit_timestamp(self, timestamp_type: TimestampType) -> Reader:
- return TimestampReader()
-
- def visit_timestampz(self, timestamptz_type: TimestamptzType) -> Reader:
- return TimestamptzReader()
-
- def visit_string(self, string_type: StringType) -> Reader:
- return StringReader()
-
- def visit_uuid(self, uuid_type: UUIDType) -> Reader:
- return UUIDReader()
-
- def visit_binary(self, binary_ype: BinaryType) -> Reader:
- return BinaryReader()
diff --git a/python/pyiceberg/avro/resolver.py
b/python/pyiceberg/avro/resolver.py
index 5542e8de3c..a53693f415 100644
--- a/python/pyiceberg/avro/resolver.py
+++ b/python/pyiceberg/avro/resolver.py
@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from functools import singledispatch
+# pylint: disable=arguments-renamed,unused-argument
from typing import (
+ Callable,
+ Dict,
List,
Optional,
Tuple,
@@ -23,114 +25,221 @@ from typing import (
)
from pyiceberg.avro.reader import (
- ConstructReader,
+ BinaryReader,
+ BooleanReader,
+ DateReader,
+ DecimalReader,
+ DoubleReader,
+ FixedReader,
+ FloatReader,
+ IntegerReader,
ListReader,
MapReader,
NoneReader,
OptionReader,
Reader,
+ StringReader,
StructReader,
+ TimeReader,
+ TimestampReader,
+ TimestamptzReader,
+ UUIDReader,
)
from pyiceberg.exceptions import ResolveError
-from pyiceberg.schema import Schema, promote, visit
+from pyiceberg.schema import (
+ PartnerAccessor,
+ PrimitiveWithPartnerVisitor,
+ Schema,
+ promote,
+ visit_with_partner,
+)
+from pyiceberg.typedef import EMPTY_DICT, StructProtocol
from pyiceberg.types import (
+ BinaryType,
+ BooleanType,
+ DateType,
+ DecimalType,
DoubleType,
+ FixedType,
FloatType,
IcebergType,
+ IntegerType,
ListType,
+ LongType,
MapType,
+ NestedField,
PrimitiveType,
+ StringType,
StructType,
+ TimestampType,
+ TimestamptzType,
+ TimeType,
+ UUIDType,
)
-@singledispatch
-def resolve(file_schema: Union[Schema, IcebergType], read_schema:
Union[Schema, IcebergType]) -> Reader:
- """This resolves the file and read schema
+def construct_reader(file_schema: Union[Schema, IcebergType]) -> Reader:
+ """Constructs a reader from a file schema
+
+ Args:
+ file_schema (Schema | IcebergType): The schema of the Avro file
+
+ Raises:
+ NotImplementedError: If attempting to resolve an unrecognized object
type
+ """
+ return resolve(file_schema, file_schema)
+
- The function traverses the schema in post-order fashion
+def resolve(
+ file_schema: Union[Schema, IcebergType],
+ read_schema: Union[Schema, IcebergType],
+ read_types: Dict[int, Callable[[Schema], StructProtocol]] = EMPTY_DICT,
+) -> Reader:
+ """Resolves the file and read schema to produce a reader
- Args:
- file_schema (Schema | IcebergType): The schema of the Avro file
- read_schema (Schema | IcebergType): The requested read schema which
is equal, subset or superset of the file schema
+ Args:
+ file_schema (Schema | IcebergType): The schema of the Avro file
+ read_schema (Schema | IcebergType): The requested read schema which is
equal, subset or superset of the file schema
+ read_types (Dict[int, Callable[[Schema], StructProtocol]]): A dict of
types to use for struct data
- Raises:
- NotImplementedError: If attempting to resolve an unrecognized object
type
+ Raises:
+ NotImplementedError: If attempting to resolve an unrecognized object
type
"""
- raise NotImplementedError(f"Cannot resolve non-type: {file_schema}")
+ return visit_with_partner(file_schema, read_schema,
SchemaResolver(read_types), SchemaPartnerAccessor()) # type: ignore
+
+
+class SchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]):
+ read_types: Optional[Dict[int, Callable[[Schema], StructProtocol]]]
+
+ def __init__(self, read_types: Optional[Dict[int, Callable[[Schema],
StructProtocol]]]):
+ self.read_types = read_types
+
+ def schema(self, schema: Schema, expected_schema: Optional[IcebergType],
result: Reader) -> Reader:
+ return result
+
+ def struct(self, struct: StructType, expected_struct:
Optional[IcebergType], field_readers: List[Reader]) -> Reader:
+ if not expected_struct:
+ return StructReader(tuple(enumerate(field_readers)))
+
+ if not isinstance(expected_struct, StructType):
+ raise ResolveError(f"File/read schema are not aligned for struct,
got {expected_struct}")
+
+ results: List[Tuple[Optional[int], Reader]] = []
+ expected_positions: Dict[int, int] = {field.field_id: pos for pos,
field in enumerate(expected_struct.fields)}
+
+ # first, add readers for the file fields that must be in order
+ for field, result_reader in zip(struct.fields, field_readers):
+ read_pos = expected_positions.get(field.field_id)
+ results.append((read_pos, result_reader))
+
+ file_fields = {field.field_id: field for field in struct.fields}
+ for pos, read_field in enumerate(expected_struct.fields):
+ if read_field.field_id not in file_fields:
+ if read_field.required:
+ raise ResolveError(f"{read_field} is non-optional, and not
part of the file schema")
+ # Just set the new field to None
+ results.append((pos, NoneReader()))
+
+ return StructReader(tuple(results))
+
+ def field(self, field: NestedField, expected_field: Optional[IcebergType],
field_reader: Reader) -> Reader:
+ return field_reader if field.required else OptionReader(field_reader)
+
+ def list(self, list_type: ListType, expected_list: Optional[IcebergType],
element_reader: Reader) -> Reader:
+ if expected_list and not isinstance(expected_list, ListType):
+ raise ResolveError(f"File/read schema are not aligned for list,
got {expected_list}")
+
+ return ListReader(element_reader if list_type.element_required else
OptionReader(element_reader))
+
+ def map(self, map_type: MapType, expected_map: Optional[IcebergType],
key_reader: Reader, value_reader: Reader) -> Reader:
+ if expected_map and not isinstance(expected_map, MapType):
+ raise ResolveError(f"File/read schema are not aligned for map, got
{expected_map}")
+
+ return MapReader(key_reader, value_reader if map_type.value_required
else OptionReader(value_reader))
+
+ def primitive(self, primitive: PrimitiveType, expected_primitive:
Optional[IcebergType]) -> Reader:
+ if expected_primitive is not None:
+ if not isinstance(expected_primitive, PrimitiveType):
+ raise ResolveError(f"File/read schema are not aligned for
{primitive}, got {expected_primitive}")
+ # ensure that the type can be projected to the expected
+ if primitive != expected_primitive:
+ promote(primitive, expected_primitive)
-@resolve.register(Schema)
-def _(file_schema: Schema, read_schema: Schema) -> Reader:
- """Visit a Schema and starts resolving it by converting it to a struct"""
- return resolve(file_schema.as_struct(), read_schema.as_struct())
+ return super().primitive(primitive, expected_primitive)
+ def visit_boolean(self, boolean_type: BooleanType, partner:
Optional[IcebergType]) -> Reader:
+ return BooleanReader()
-@resolve.register(StructType)
-def _(file_struct: StructType, read_struct: IcebergType) -> Reader:
- """Iterates over the file schema, and checks if the field is in the read
schema"""
+ def visit_integer(self, integer_type: IntegerType, partner:
Optional[IcebergType]) -> Reader:
+ return IntegerReader()
- if not isinstance(read_struct, StructType):
- raise ResolveError(f"File/read schema are not aligned for
{file_struct}, got {read_struct}")
+ def visit_long(self, long_type: LongType, partner: Optional[IcebergType])
-> Reader:
+ return IntegerReader()
- results: List[Tuple[Optional[int], Reader]] = []
- read_fields = {field.field_id: (pos, field) for pos, field in
enumerate(read_struct.fields)}
+ def visit_float(self, float_type: FloatType, partner:
Optional[IcebergType]) -> Reader:
+ return FloatReader()
- for file_field in file_struct.fields:
- if file_field.field_id in read_fields:
- read_pos, read_field = read_fields[file_field.field_id]
- result_reader = resolve(file_field.field_type,
read_field.field_type)
+ def visit_double(self, double_type: DoubleType, partner:
Optional[IcebergType]) -> Reader:
+ return DoubleReader()
+
+ def visit_decimal(self, decimal_type: DecimalType, partner:
Optional[IcebergType]) -> Reader:
+ return DecimalReader(decimal_type.precision, decimal_type.scale)
+
+ def visit_date(self, date_type: DateType, partner: Optional[IcebergType])
-> Reader:
+ return DateReader()
+
+ def visit_time(self, time_type: TimeType, partner: Optional[IcebergType])
-> Reader:
+ return TimeReader()
+
+ def visit_timestamp(self, timestamp_type: TimestampType, partner:
Optional[IcebergType]) -> Reader:
+ return TimestampReader()
+
+ def visit_timestampz(self, timestamptz_type: TimestamptzType, partner:
Optional[IcebergType]) -> Reader:
+ return TimestamptzReader()
+
+ def visit_string(self, string_type: StringType, partner:
Optional[IcebergType]) -> Reader:
+ return StringReader()
+
+ def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType])
-> Reader:
+ return UUIDReader()
+
+ def visit_fixed(self, fixed_type: FixedType, partner:
Optional[IcebergType]) -> Reader:
+ return FixedReader(len(fixed_type))
+
+ def visit_binary(self, binary_type: BinaryType, partner:
Optional[IcebergType]) -> Reader:
+ return BinaryReader()
+
+
+class SchemaPartnerAccessor(PartnerAccessor[IcebergType]):
+ def schema_partner(self, partner: Optional[IcebergType]) ->
Optional[IcebergType]:
+ if isinstance(partner, Schema):
+ return partner.as_struct()
+
+ raise ResolveError(f"File/read schema are not aligned for schema, got
{partner}")
+
+ def field_partner(self, partner: Optional[IcebergType], field_id: int,
field_name: str) -> Optional[IcebergType]:
+ if partner is None or isinstance(partner, StructType):
+ field = partner.field(field_id) if partner is not None else None
else:
- read_pos = None
- result_reader = visit(file_field.field_type, ConstructReader())
- result_reader = result_reader if file_field.required else
OptionReader(result_reader)
- results.append((read_pos, result_reader))
-
- file_fields = {field.field_id: field for field in file_struct.fields}
- for pos, read_field in enumerate(read_struct.fields):
- if read_field.field_id not in file_fields:
- if read_field.required:
- raise ResolveError(f"{read_field} is non-optional, and not
part of the file schema")
- # Just set the new field to None
- results.append((pos, NoneReader()))
-
- return StructReader(tuple(results))
-
-
-@resolve.register(ListType)
-def _(file_list: ListType, read_list: IcebergType) -> Reader:
- if not isinstance(read_list, ListType):
- raise ResolveError(f"File/read schema are not aligned for {file_list},
got {read_list}")
- element_reader = resolve(file_list.element_type, read_list.element_type)
- return ListReader(element_reader)
-
-
-@resolve.register(MapType)
-def _(file_map: MapType, read_map: IcebergType) -> Reader:
- if not isinstance(read_map, MapType):
- raise ResolveError(f"File/read schema are not aligned for {file_map},
got {read_map}")
- key_reader = resolve(file_map.key_type, read_map.key_type)
- value_reader = resolve(file_map.value_type, read_map.value_type)
-
- return MapReader(key_reader, value_reader)
-
-
-@resolve.register(FloatType)
-def _(file_type: PrimitiveType, read_type: IcebergType) -> Reader:
- """This is a special case, when we need to adhere to the bytes written"""
- if isinstance(read_type, DoubleType):
- return visit(file_type, ConstructReader())
- else:
- raise ResolveError(f"Cannot promote an float to {read_type}")
-
-
-@resolve.register(PrimitiveType)
-def _(file_type: PrimitiveType, read_type: IcebergType) -> Reader:
- """Converting the primitive type into an actual reader that will decode
the physical data"""
- if not isinstance(read_type, PrimitiveType):
- raise ResolveError(f"Cannot promote {file_type} to {read_type}")
-
- # In the case of a promotion, we want to check if it is valid
- if file_type != read_type:
- read_type = promote(file_type, read_type)
- return visit(read_type, ConstructReader())
+ raise ResolveError(f"File/read schema are not aligned for struct,
got {partner}")
+
+ return field.field_type if field else None
+
+ def list_element_partner(self, partner_list: Optional[IcebergType]) ->
Optional[IcebergType]:
+ if partner_list is None or isinstance(partner_list, ListType):
+ return partner_list.element_type if partner_list is not None else None
+
+ raise ResolveError(f"File/read schema are not aligned for list, got
{partner_list}")
+
+ def map_key_partner(self, partner_map: Optional[IcebergType]) ->
Optional[IcebergType]:
+ if partner_map is None or isinstance(partner_map, MapType):
+ return partner_map.key_type if partner_map is not None else None
+
+ raise ResolveError(f"File/read schema are not aligned for map, got
{partner_map}")
+
+ def map_value_partner(self, partner_map: Optional[IcebergType]) ->
Optional[IcebergType]:
+ if partner_map is None or isinstance(partner_map, MapType):
+ return partner_map.value_type if partner_map is not None else None
+
+ raise ResolveError(f"File/read schema are not aligned for map, got
{partner_map}")
diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index b4c9024f2d..adddff82fa 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -51,7 +51,7 @@ from pyarrow.fs import (
S3FileSystem,
)
-from pyiceberg.avro.resolver import ResolveError, promote
+from pyiceberg.avro.resolver import ResolveError
from pyiceberg.expressions import (
AlwaysTrue,
BooleanExpression,
@@ -77,6 +77,7 @@ from pyiceberg.schema import (
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
+ promote,
prune_columns,
visit,
visit_with_partner,
@@ -605,6 +606,9 @@ class ArrowAccessor(PartnerAccessor[pa.Array]):
def __init__(self, file_schema: Schema):
self.file_schema = file_schema
+ def schema_partner(self, partner: Optional[pa.Array]) ->
Optional[pa.Array]:
+ return partner
+
def field_partner(self, partner_struct: Optional[pa.Array], field_id: int,
_: str) -> Optional[pa.Array]:
if partner_struct:
# use the field name from the file schema
diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index 33f5cf3dc0..1a85cb4baa 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -394,7 +394,102 @@ class SchemaWithPartnerVisitor(Generic[P, T], ABC):
"""Visit a primitive type with a partner"""
+class PrimitiveWithPartnerVisitor(SchemaWithPartnerVisitor[P, T]):
+ def primitive(self, primitive: PrimitiveType, primitive_partner:
Optional[P]) -> T:
+ """Visit a PrimitiveType"""
+ if isinstance(primitive, BooleanType):
+ return self.visit_boolean(primitive, primitive_partner)
+ elif isinstance(primitive, IntegerType):
+ return self.visit_integer(primitive, primitive_partner)
+ elif isinstance(primitive, LongType):
+ return self.visit_long(primitive, primitive_partner)
+ elif isinstance(primitive, FloatType):
+ return self.visit_float(primitive, primitive_partner)
+ elif isinstance(primitive, DoubleType):
+ return self.visit_double(primitive, primitive_partner)
+ elif isinstance(primitive, DecimalType):
+ return self.visit_decimal(primitive, primitive_partner)
+ elif isinstance(primitive, DateType):
+ return self.visit_date(primitive, primitive_partner)
+ elif isinstance(primitive, TimeType):
+ return self.visit_time(primitive, primitive_partner)
+ elif isinstance(primitive, TimestampType):
+ return self.visit_timestamp(primitive, primitive_partner)
+ elif isinstance(primitive, TimestamptzType):
+ return self.visit_timestampz(primitive, primitive_partner)
+ elif isinstance(primitive, StringType):
+ return self.visit_string(primitive, primitive_partner)
+ elif isinstance(primitive, UUIDType):
+ return self.visit_uuid(primitive, primitive_partner)
+ elif isinstance(primitive, FixedType):
+ return self.visit_fixed(primitive, primitive_partner)
+ elif isinstance(primitive, BinaryType):
+ return self.visit_binary(primitive, primitive_partner)
+ else:
+ raise ValueError(f"Unknown type: {primitive}")
+
+ @abstractmethod
+ def visit_boolean(self, boolean_type: BooleanType, partner: Optional[P])
-> T:
+ """Visit a BooleanType"""
+
+ @abstractmethod
+ def visit_integer(self, integer_type: IntegerType, partner: Optional[P])
-> T:
+ """Visit a IntegerType"""
+
+ @abstractmethod
+ def visit_long(self, long_type: LongType, partner: Optional[P]) -> T:
+ """Visit a LongType"""
+
+ @abstractmethod
+ def visit_float(self, float_type: FloatType, partner: Optional[P]) -> T:
+ """Visit a FloatType"""
+
+ @abstractmethod
+ def visit_double(self, double_type: DoubleType, partner: Optional[P]) -> T:
+ """Visit a DoubleType"""
+
+ @abstractmethod
+ def visit_decimal(self, decimal_type: DecimalType, partner: Optional[P])
-> T:
+ """Visit a DecimalType"""
+
+ @abstractmethod
+ def visit_date(self, date_type: DateType, partner: Optional[P]) -> T:
+ """Visit a DateType"""
+
+ @abstractmethod
+ def visit_time(self, time_type: TimeType, partner: Optional[P]) -> T:
+ """Visit a TimeType"""
+
+ @abstractmethod
+ def visit_timestamp(self, timestamp_type: TimestampType, partner:
Optional[P]) -> T:
+ """Visit a TimestampType"""
+
+ @abstractmethod
+ def visit_timestampz(self, timestamptz_type: TimestamptzType, partner:
Optional[P]) -> T:
+ """Visit a TimestamptzType"""
+
+ @abstractmethod
+ def visit_string(self, string_type: StringType, partner: Optional[P]) -> T:
+ """Visit a StringType"""
+
+ @abstractmethod
+ def visit_uuid(self, uuid_type: UUIDType, partner: Optional[P]) -> T:
+ """Visit a UUIDType"""
+
+ @abstractmethod
+ def visit_fixed(self, fixed_type: FixedType, partner: Optional[P]) -> T:
+ """Visit a FixedType"""
+
+ @abstractmethod
+ def visit_binary(self, binary_type: BinaryType, partner: Optional[P]) -> T:
+ """Visit a BinaryType"""
+
+
class PartnerAccessor(Generic[P], ABC):
+ @abstractmethod
+ def schema_partner(self, partner: Optional[P]) -> Optional[P]:
+ """Returns the equivalent of the schema as a struct"""
+
@abstractmethod
def field_partner(self, partner_struct: Optional[P], field_id: int,
field_name: str) -> Optional[P]:
"""Returns the equivalent struct field by name or id in the partner
struct"""
@@ -416,12 +511,13 @@ class PartnerAccessor(Generic[P], ABC):
def visit_with_partner(
schema_or_type: Union[Schema, IcebergType], partner: P, visitor:
SchemaWithPartnerVisitor[T, P], accessor: PartnerAccessor[P]
) -> T:
- raise ValueError(f"Unsupported type: {type}")
+ raise ValueError(f"Unsupported type: {schema_or_type}")
@visit_with_partner.register(Schema)
def _(schema: Schema, partner: P, visitor: SchemaWithPartnerVisitor[P, T],
accessor: PartnerAccessor[P]) -> T:
- return visitor.schema(schema, partner,
visit_with_partner(schema.as_struct(), partner, visitor, accessor)) # type:
ignore
+ struct_partner = accessor.schema_partner(partner)
+ return visitor.schema(schema, partner,
visit_with_partner(schema.as_struct(), struct_partner, visitor, accessor)) #
type: ignore
@visit_with_partner.register(StructType)
@@ -561,7 +657,7 @@ class SchemaVisitorPerPrimitiveType(SchemaVisitor[T], ABC):
"""Visit a UUIDType"""
@abstractmethod
- def visit_binary(self, binary_ype: BinaryType) -> T:
+ def visit_binary(self, binary_type: BinaryType) -> T:
"""Visit a BinaryType"""
diff --git a/python/pyiceberg/typedef.py b/python/pyiceberg/typedef.py
index c924d703bd..228b9a927c 100644
--- a/python/pyiceberg/typedef.py
+++ b/python/pyiceberg/typedef.py
@@ -83,10 +83,14 @@ class StructProtocol(Protocol): # pragma: no cover
class Record(StructProtocol):
_data: List[Union[Any, StructProtocol]]
+ @staticmethod
+ def of(num_fields: int) -> Record:
+ return Record(*([None] * num_fields))
+
def __init__(self, *data: Union[Any, StructProtocol]) -> None:
self._data = list(data)
def set(self, pos: int, value: Any) -> None:
self._data[pos] = value
def get(self, pos: int) -> Any:
diff --git a/python/pyiceberg/types.py b/python/pyiceberg/types.py
index 6236358063..6a736ed19d 100644
--- a/python/pyiceberg/types.py
+++ b/python/pyiceberg/types.py
@@ -268,6 +268,12 @@ class StructType(IcebergType):
data["fields"] = fields
super().__init__(**data)
+ def field(self, field_id: int) -> Optional[NestedField]:
+ for field in self.fields:
+ if field.field_id == field_id:
+ return field
+ return None
+
def __str__(self) -> str:
return f"struct<{', '.join(map(str, self.fields))}>"
diff --git a/python/tests/avro/test_reader.py b/python/tests/avro/test_reader.py
index 2d54b6e887..3fc99ce3f9 100644
--- a/python/tests/avro/test_reader.py
+++ b/python/tests/avro/test_reader.py
@@ -23,7 +23,6 @@ from pyiceberg.avro.file import AvroFile
from pyiceberg.avro.reader import (
BinaryReader,
BooleanReader,
- ConstructReader,
DateReader,
DecimalReader,
DoubleReader,
@@ -36,9 +35,10 @@ from pyiceberg.avro.reader import (
TimestamptzReader,
UUIDReader,
)
+from pyiceberg.avro.resolver import construct_reader
from pyiceberg.io.pyarrow import PyArrowFileIO
from pyiceberg.manifest import _convert_pos_to_dict
-from pyiceberg.schema import Schema, visit
+from pyiceberg.schema import Schema
from pyiceberg.typedef import Record
from pyiceberg.types import (
BinaryType,
@@ -446,55 +446,55 @@ def test_null_struct_convert_pos_to_dict() -> None:
def test_fixed_reader() -> None:
- assert visit(FixedType(22), ConstructReader()) == FixedReader(22)
+ assert construct_reader(FixedType(22)) == FixedReader(22)
def test_decimal_reader() -> None:
- assert visit(DecimalType(19, 25), ConstructReader()) == DecimalReader(19,
25)
+ assert construct_reader(DecimalType(19, 25)) == DecimalReader(19, 25)
def test_boolean_reader() -> None:
- assert visit(BooleanType(), ConstructReader()) == BooleanReader()
+ assert construct_reader(BooleanType()) == BooleanReader()
def test_integer_reader() -> None:
- assert visit(IntegerType(), ConstructReader()) == IntegerReader()
+ assert construct_reader(IntegerType()) == IntegerReader()
def test_long_reader() -> None:
- assert visit(LongType(), ConstructReader()) == IntegerReader()
+ assert construct_reader(LongType()) == IntegerReader()
def test_float_reader() -> None:
- assert visit(FloatType(), ConstructReader()) == FloatReader()
+ assert construct_reader(FloatType()) == FloatReader()
def test_double_reader() -> None:
- assert visit(DoubleType(), ConstructReader()) == DoubleReader()
+ assert construct_reader(DoubleType()) == DoubleReader()
def test_date_reader() -> None:
- assert visit(DateType(), ConstructReader()) == DateReader()
+ assert construct_reader(DateType()) == DateReader()
def test_time_reader() -> None:
- assert visit(TimeType(), ConstructReader()) == TimeReader()
+ assert construct_reader(TimeType()) == TimeReader()
def test_timestamp_reader() -> None:
- assert visit(TimestampType(), ConstructReader()) == TimestampReader()
+ assert construct_reader(TimestampType()) == TimestampReader()
def test_timestamptz_reader() -> None:
- assert visit(TimestamptzType(), ConstructReader()) == TimestamptzReader()
+ assert construct_reader(TimestamptzType()) == TimestamptzReader()
def test_string_reader() -> None:
- assert visit(StringType(), ConstructReader()) == StringReader()
+ assert construct_reader(StringType()) == StringReader()
def test_binary_reader() -> None:
- assert visit(BinaryType(), ConstructReader()) == BinaryReader()
+ assert construct_reader(BinaryType()) == BinaryReader()
def test_unknown_type() -> None:
@@ -502,10 +502,10 @@ def test_unknown_type() -> None:
__root__ = "UnknownType"
with pytest.raises(ValueError) as exc_info:
- visit(UnknownType(), ConstructReader())
+ construct_reader(UnknownType())
assert "Unknown type:" in str(exc_info.value)
def test_uuid_reader() -> None:
- assert visit(UUIDType(), ConstructReader()) == UUIDReader()
+ assert construct_reader(UUIDType()) == UUIDReader()
diff --git a/python/tests/avro/test_resolver.py
b/python/tests/avro/test_resolver.py
index f051881e0d..c36b76922a 100644
--- a/python/tests/avro/test_resolver.py
+++ b/python/tests/avro/test_resolver.py
@@ -160,7 +160,7 @@ def test_resolver_change_type() -> None:
with pytest.raises(ResolveError) as exc_info:
resolve(write_schema, read_schema)
- assert "File/read schema are not aligned for list<string>, got map<string,
string>" in str(exc_info.value)
+ assert "File/read schema are not aligned for list, got map<string,
string>" in str(exc_info.value)
def test_resolve_int_to_long() -> None:
@@ -174,7 +174,7 @@ def test_resolve_float_to_double() -> None:
def test_resolve_decimal_to_decimal() -> None:
# DecimalType(P, S) to DecimalType(P2, S) where P2 > P
- assert resolve(DecimalType(19, 25), DecimalType(22, 25)) ==
DecimalReader(22, 25)
+ assert resolve(DecimalType(19, 25), DecimalType(22, 25)) ==
DecimalReader(19, 25)
def test_struct_not_aligned() -> None: