This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new b5b916bd90 Python: Improve AVRO reading speed (#8084)
b5b916bd90 is described below
commit b5b916bd90d5fd60e736fade47874b52ca30ab39
Author: Rusty Conover <[email protected]>
AuthorDate: Tue Jul 25 02:33:55 2023 -0400
Python: Improve AVRO reading speed (#8084)
* Python: Improve Avro read performance
Utilize __slots__ for various Avro classes.
Memoize the key/value reader functions in MapReader.
* Python: improve Avro parsing by caching struct types.
Cache the struct types for ManifestEntry and ManifestFile schemas.
Cache the _position_to_field_name lookup in Records if a struct type is passed.
Improve StructReader to use slots and avoid isinstance() lookups each time
the structure is read, by performing the check in __init__().
Improve StructReader to determine the arguments to create_struct() once in
__init__() and reuse the result each time the structure is read, avoiding a
try block.
* Python: Improve Avro Parse speed
Cache the hash value of the Struct types so that when new
Record objects are created time is not spent recomputing
the hash value.
Change the implementation of the Record class to cache the
mapping from field name to position. Also utilize an array
rather than a dictionary for these lookups.
* fix some lints
* fix some type casing
* fix: adjust PR comments
* fix: address PR feedback.
* Add additional tests for InMemoryBinaryDecoder
* Load entire Avro file into memory rather than streaming it
* fix: disable seeking on input file
* Update python/pyiceberg/avro/file.py
---------
Co-authored-by: Fokko Driesprong <[email protected]>
---
python/pyiceberg/avro/decoder.py | 192 ++++++++++++++++++++++++++++++++------
python/pyiceberg/avro/file.py | 20 ++--
python/pyiceberg/avro/reader.py | 114 ++++++++++++++++------
python/pyiceberg/manifest.py | 8 +-
python/tests/avro/test_decoder.py | 140 ++++++++++++++++-----------
python/tests/avro/test_reader.py | 26 ++++--
6 files changed, 362 insertions(+), 138 deletions(-)
diff --git a/python/pyiceberg/avro/decoder.py b/python/pyiceberg/avro/decoder.py
index e680c00bcd..35a1651192 100644
--- a/python/pyiceberg/avro/decoder.py
+++ b/python/pyiceberg/avro/decoder.py
@@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import decimal
+from abc import ABC, abstractmethod
from datetime import datetime, time
from io import SEEK_CUR
-from typing import List
+from typing import Dict, List
from uuid import UUID
from pyiceberg.avro import STRUCT_DOUBLE, STRUCT_FLOAT
@@ -26,39 +27,24 @@ from pyiceberg.utils.datetime import micros_to_time,
micros_to_timestamp, micros
from pyiceberg.utils.decimal import unscaled_to_decimal
-class BinaryDecoder:
+class BinaryDecoder(ABC):
"""Read leaf values."""
- __slots__ = "_input_stream"
- _input_stream: InputStream
-
+ @abstractmethod
def __init__(self, input_stream: InputStream) -> None:
- """Reader is a Python object on which we can call read, seek, and
tell."""
- self._input_stream = input_stream
+ """Create the decoder."""
+
+ @abstractmethod
+ def tell(self) -> int:
+ """Return the current position."""
+ @abstractmethod
def read(self, n: int) -> bytes:
"""Read n bytes."""
- if n < 0:
- raise ValueError(f"Requested {n} bytes to read, expected positive
integer.")
- data: List[bytes] = []
-
- n_remaining = n
- while n_remaining > 0:
- data_read = self._input_stream.read(n_remaining)
- read_len = len(data_read)
- if read_len == n:
- # If we read everything, we return directly
- # otherwise we'll continue to fetch the rest
- return data_read
- elif read_len <= 0:
- raise EOFError(f"EOF: read {read_len} bytes")
- data.append(data_read)
- n_remaining -= read_len
-
- return b"".join(data)
+ @abstractmethod
def skip(self, n: int) -> None:
- self._input_stream.seek(n, SEEK_CUR)
+ """Skip n bytes."""
def read_boolean(self) -> bool:
"""Reads a value from the stream as a boolean.
@@ -69,7 +55,7 @@ class BinaryDecoder:
return ord(self.read(1)) == 1
def read_int(self) -> int:
- """Reads a value from the stream as an integer.
+ """Reads an int/long value.
int/long values are written using variable-length, zigzag coding.
"""
@@ -83,6 +69,25 @@ class BinaryDecoder:
datum = (n >> 1) ^ -(n & 1)
return datum
+ def read_ints(self, n: int, dest: List[int]) -> None:
+ """Reads a list of integers."""
+ for _ in range(n):
+ dest.append(self.read_int())
+
+ def read_int_int_dict(self, n: int, dest: Dict[int, int]) -> None:
+ """Reads a dictionary of integers for keys and values into a
destination dictionary."""
+ for _ in range(n):
+ k = self.read_int()
+ v = self.read_int()
+ dest[k] = v
+
+ def read_int_bytes_dict(self, n: int, dest: Dict[int, bytes]) -> None:
+ """Reads a dictionary of integers for keys and bytes for values into a
destination dictionary."""
+ for _ in range(n):
+ k = self.read_int()
+ v = self.read_bytes()
+ dest[k] = v
+
def read_float(self) -> float:
"""Reads a value from the stream as a float.
@@ -191,3 +196,136 @@ class BinaryDecoder:
def skip_utf8(self) -> None:
self.skip_bytes()
+
+
+class StreamingBinaryDecoder(BinaryDecoder):
+ """Read leaf values."""
+
+ __slots__ = "_input_stream"
+ _input_stream: InputStream
+
+ def __init__(self, input_stream: InputStream) -> None:
+ """Reader is a Python object on which we can call read, seek, and
tell."""
+ super().__init__(input_stream)
+ self._input_stream = input_stream
+
+ def tell(self) -> int:
+ """Return the current stream position."""
+ return self._input_stream.tell()
+
+ def read(self, n: int) -> bytes:
+ """Read n bytes."""
+ if n < 0:
+ raise ValueError(f"Requested {n} bytes to read, expected positive
integer.")
+ data: List[bytes] = []
+
+ n_remaining = n
+ while n_remaining > 0:
+ data_read = self._input_stream.read(n_remaining)
+ read_len = len(data_read)
+ if read_len == n:
+ # If we read everything, we return directly
+ # otherwise we'll continue to fetch the rest
+ return data_read
+ elif read_len <= 0:
+ raise EOFError(f"EOF: read {read_len} bytes")
+ data.append(data_read)
+ n_remaining -= read_len
+
+ return b"".join(data)
+
+ def skip(self, n: int) -> None:
+ self._input_stream.seek(n, SEEK_CUR)
+
+
+class InMemoryBinaryDecoder(BinaryDecoder):
+ """Implement a BinaryDecoder that reads from an in-memory buffer.
+
+ This may be more efficient if the entire block is already in memory
+ as it does not need to interact with the I/O subsystem.
+ """
+
+ __slots__ = ["_contents", "_position", "_size"]
+ _contents: bytes
+ _position: int
+ _size: int
+
+ def __init__(self, input_stream: InputStream) -> None:
+ """Reader is a Python object on which we can call read, seek, and
tell."""
+ super().__init__(input_stream)
+ self._contents = input_stream.read()
+ self._size = len(self._contents)
+ self._position = 0
+
+ def tell(self) -> int:
+ """Return the current stream position."""
+ return self._position
+
+ def read(self, n: int) -> bytes:
+ """Read n bytes."""
+ if n < 0:
+ raise ValueError(f"Requested {n} bytes to read, expected positive
integer.")
+ if self._position + n > self._size:
+ raise EOFError(f"EOF: read {n} bytes")
+ r = self._contents[self._position : self._position + n]
+ self._position += n
+ return r
+
+ def skip(self, n: int) -> None:
+ self._position += n
+
+ def read_boolean(self) -> bool:
+ """Reads a value from the stream as a boolean.
+
+ A boolean is written as a single byte
+ whose value is either 0 (false) or 1 (true).
+ """
+ r = self._contents[self._position]
+ self._position += 1
+ return r != 0
+
+ def read_int(self) -> int:
+ """Reads a value from the stream as an integer.
+
+ int/long values are written using variable-length, zigzag coding.
+ """
+ if self._position == self._size:
+ raise EOFError("EOF: read 1 byte")
+ b = self._contents[self._position]
+ self._position += 1
+ n = b & 0x7F
+ shift = 7
+ while b & 0x80:
+ b = self._contents[self._position]
+ self._position += 1
+ n |= (b & 0x7F) << shift
+ shift += 7
+ return (n >> 1) ^ -(n & 1)
+
+ def read_int_bytes_dict(self, n: int, dest: Dict[int, bytes]) -> None:
+ """Reads a dictionary of integers for keys and bytes for values into a
destination dict."""
+ for _ in range(n):
+ k = self.read_int()
+
+ byte_length = self.read_int()
+ if byte_length <= 0:
+ dest[k] = b""
+ else:
+ dest[k] = self._contents[self._position : self._position +
byte_length]
+ self._position += byte_length
+
+ def read_bytes(self) -> bytes:
+ """Bytes are encoded as a long followed by that many bytes of data."""
+ num_bytes = self.read_int()
+ if num_bytes <= 0:
+ return b""
+ r = self._contents[self._position : self._position + num_bytes]
+ self._position += num_bytes
+ return r
+
+ def skip_int(self) -> None:
+ b = self._contents[self._position]
+ self._position += 1
+ while b & 0x80:
+ b = self._contents[self._position]
+ self._position += 1
diff --git a/python/pyiceberg/avro/file.py b/python/pyiceberg/avro/file.py
index 520eda9ce1..f780a8a30c 100644
--- a/python/pyiceberg/avro/file.py
+++ b/python/pyiceberg/avro/file.py
@@ -35,17 +35,12 @@ from typing import (
)
from pyiceberg.avro.codecs import KNOWN_CODECS, Codec
-from pyiceberg.avro.decoder import BinaryDecoder
+from pyiceberg.avro.decoder import BinaryDecoder, InMemoryBinaryDecoder
from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.reader import Reader
from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve
from pyiceberg.avro.writer import Writer
-from pyiceberg.io import (
- InputFile,
- InputStream,
- OutputFile,
- OutputStream,
-)
+from pyiceberg.io import InputFile, OutputFile, OutputStream
from pyiceberg.schema import Schema
from pyiceberg.typedef import EMPTY_DICT, Record, StructProtocol
from pyiceberg.types import (
@@ -134,7 +129,6 @@ class AvroFile(Generic[D]):
"read_schema",
"read_types",
"read_enums",
- "input_stream",
"header",
"schema",
"reader",
@@ -145,7 +139,6 @@ class AvroFile(Generic[D]):
read_schema: Optional[Schema]
read_types: Dict[int, Callable[..., StructProtocol]]
read_enums: Dict[int, Callable[..., Enum]]
- input_stream: InputStream
header: AvroFileHeader
schema: Schema
reader: Reader
@@ -172,8 +165,8 @@ class AvroFile(Generic[D]):
Returns:
A generator returning the AvroStructs.
"""
- self.input_stream = self.input_file.open(seekable=False)
- self.decoder = BinaryDecoder(self.input_stream)
+ with self.input_file.open() as f:
+ self.decoder = InMemoryBinaryDecoder(io.BytesIO(f.read()))
self.header = self._read_header()
self.schema = self.header.get_schema()
if not self.read_schema:
@@ -187,7 +180,6 @@ class AvroFile(Generic[D]):
self, exctype: Optional[Type[BaseException]], excinst:
Optional[BaseException], exctb: Optional[TracebackType]
) -> None:
"""Performs cleanup when exiting the scope of a 'with' statement."""
- self.input_stream.close()
def __iter__(self) -> AvroFile[D]:
"""Returns an iterator for the AvroFile class."""
@@ -206,7 +198,9 @@ class AvroFile(Generic[D]):
if codec := self.header.compression_codec():
block_bytes = codec.decompress(block_bytes)
- self.block = Block(reader=self.reader, block_records=block_records,
block_decoder=BinaryDecoder(io.BytesIO(block_bytes)))
+ self.block = Block(
+ reader=self.reader, block_records=block_records,
block_decoder=InMemoryBinaryDecoder(io.BytesIO(block_bytes))
+ )
return block_records
def __next__(self) -> D:
diff --git a/python/pyiceberg/avro/reader.py b/python/pyiceberg/avro/reader.py
index 7d3d8fcec2..ba9456e892 100644
--- a/python/pyiceberg/avro/reader.py
+++ b/python/pyiceberg/avro/reader.py
@@ -267,10 +267,11 @@ class OptionReader(Reader):
class StructReader(Reader):
- __slots__ = ("field_readers", "create_struct", "struct")
+ __slots__ = ("field_readers", "create_struct", "struct",
"_create_with_keyword", "_field_reader_functions", "_hash")
field_readers: Tuple[Tuple[Optional[int], Reader], ...]
create_struct: Callable[..., StructProtocol]
struct: StructType
+ field_reader_functions = Tuple[Tuple[Optional[str], int,
Optional[Callable[[BinaryDecoder], Any]]], ...]
def __init__(
self,
@@ -282,24 +283,38 @@ class StructReader(Reader):
self.create_struct = create_struct
self.struct = struct
- def read(self, decoder: BinaryDecoder) -> StructProtocol:
try:
# Try initializing the struct, first with the struct keyword
argument
- struct = self.create_struct(struct=self.struct)
+ created_struct = self.create_struct(struct=self.struct)
+ self._create_with_keyword = True
except TypeError as e:
if "'struct' is an invalid keyword argument for" in str(e):
- struct = self.create_struct()
+ created_struct = self.create_struct()
+ self._create_with_keyword = False
else:
raise ValueError(f"Unable to initialize struct:
{self.create_struct}") from e
- if not isinstance(struct, StructProtocol):
+ if not isinstance(created_struct, StructProtocol):
raise ValueError(f"Incompatible with StructProtocol:
{self.create_struct}")
- for pos, field in self.field_readers:
+ reading_callbacks: List[Tuple[Optional[int], Callable[[BinaryDecoder],
Any]]] = []
+ for pos, field in field_readers:
+ if pos is not None:
+ reading_callbacks.append((pos, field.read))
+ else:
+ reading_callbacks.append((None, field.skip))
+
+ self._field_reader_functions = tuple(reading_callbacks)
+ self._hash = hash(self._field_reader_functions)
+
+ def read(self, decoder: BinaryDecoder) -> StructProtocol:
+ struct = self.create_struct(struct=self.struct) if
self._create_with_keyword else self.create_struct()
+
+ for pos, field_reader in self._field_reader_functions:
if pos is not None:
- struct[pos] = field.read(decoder) # later: pass reuse in here
+ struct[pos] = field_reader(decoder) # later: pass reuse in
here
else:
- field.skip(decoder)
+ field_reader(decoder)
return struct
@@ -321,51 +336,88 @@ class StructReader(Reader):
def __hash__(self) -> int:
"""Returns a hashed representation of the StructReader class."""
- return hash(self.field_readers)
+ return self._hash
-@dataclass(frozen=True)
+@dataclass(frozen=False, init=False)
class ListReader(Reader):
- __slots__ = ("element",)
+ __slots__ = ("element", "_is_int_list", "_hash")
element: Reader
+ def __init__(self, element: Reader) -> None:
+ super().__init__()
+ self.element = element
+ self._hash = hash(self.element)
+ self._is_int_list = isinstance(self.element, IntegerReader)
+
def read(self, decoder: BinaryDecoder) -> List[Any]:
- read_items = []
+ read_items: List[Any] = []
block_count = decoder.read_int()
while block_count != 0:
if block_count < 0:
block_count = -block_count
_ = decoder.read_int()
- for _ in range(block_count):
- read_items.append(self.element.read(decoder))
+ if self._is_int_list:
+ decoder.read_ints(block_count, read_items)
+ else:
+ for _ in range(block_count):
+ read_items.append(self.element.read(decoder))
block_count = decoder.read_int()
return read_items
def skip(self, decoder: BinaryDecoder) -> None:
_skip_map_array(decoder, lambda: self.element.skip(decoder))
+ def __hash__(self) -> int:
+ """Returns a hashed representation of the ListReader class."""
+ return self._hash
-@dataclass(frozen=True)
+
+@dataclass(frozen=False, init=False)
class MapReader(Reader):
- __slots__ = ("key", "value")
+ __slots__ = ("key", "value", "_is_int_int", "_is_int_bytes",
"_key_reader", "_value_reader", "_hash")
key: Reader
value: Reader
+ def __init__(self, key: Reader, value: Reader) -> None:
+ super().__init__()
+ self.key = key
+ self.value = value
+ if isinstance(self.key, IntegerReader):
+ self._is_int_int = isinstance(self.value, IntegerReader)
+ self._is_int_bytes = isinstance(self.value, BinaryReader)
+ else:
+ self._is_int_int = False
+ self._is_int_bytes = False
+ self._key_reader = self.key.read
+ self._value_reader = self.value.read
+ self._hash = hash((self.key, self.value))
+
def read(self, decoder: BinaryDecoder) -> Dict[Any, Any]:
- read_items = {}
- block_count = decoder.read_int()
- key_reader = self.key.read
- value_reader = self.value.read
+ read_items: dict[Any, Any] = {}
- while block_count != 0:
- if block_count < 0:
- block_count = -block_count
- # We ignore the block size for now
- _ = decoder.read_int()
- for _ in range(block_count):
- key = key_reader(decoder)
- read_items[key] = value_reader(decoder)
- block_count = decoder.read_int()
+ block_count = decoder.read_int()
+ if self._is_int_int or self._is_int_bytes:
+ while block_count != 0:
+ if block_count < 0:
+ block_count = -block_count
+ # We ignore the block size for now
+ _ = decoder.read_int()
+ if self._is_int_int:
+ decoder.read_int_int_dict(block_count, read_items)
+ else:
+ decoder.read_int_bytes_dict(block_count, read_items)
+ block_count = decoder.read_int()
+ else:
+ while block_count != 0:
+ if block_count < 0:
+ block_count = -block_count
+ # We ignore the block size for now
+ _ = decoder.read_int()
+ for _ in range(block_count):
+ key = self._key_reader(decoder)
+ read_items[key] = self._value_reader(decoder)
+ block_count = decoder.read_int()
return read_items
@@ -375,3 +427,7 @@ class MapReader(Reader):
self.value.skip(decoder)
_skip_map_array(decoder, skip)
+
+ def __hash__(self) -> int:
+ """Returns a hashed representation of the MapReader class."""
+ return self._hash
diff --git a/python/pyiceberg/manifest.py b/python/pyiceberg/manifest.py
index ce0e47fc24..879872f892 100644
--- a/python/pyiceberg/manifest.py
+++ b/python/pyiceberg/manifest.py
@@ -231,6 +231,8 @@ MANIFEST_ENTRY_SCHEMA = Schema(
NestedField(2, "data_file", DATA_FILE_TYPE, required=True),
)
+MANIFEST_ENTRY_SCHEMA_STRUCT = MANIFEST_ENTRY_SCHEMA.as_struct()
+
class ManifestEntry(Record):
__slots__ = ("status", "snapshot_id", "data_sequence_number",
"file_sequence_number", "data_file")
@@ -241,7 +243,7 @@ class ManifestEntry(Record):
data_file: DataFile
def __init__(self, *data: Any, **named_data: Any) -> None:
- super().__init__(*data, **{"struct":
MANIFEST_ENTRY_SCHEMA.as_struct(), **named_data})
+ super().__init__(*data, **{"struct": MANIFEST_ENTRY_SCHEMA_STRUCT,
**named_data})
PARTITION_FIELD_SUMMARY_TYPE = StructType(
@@ -281,6 +283,8 @@ MANIFEST_FILE_SCHEMA: Schema = Schema(
NestedField(519, "key_metadata", BinaryType(), required=False),
)
+MANIFEST_FILE_SCHEMA_STRUCT = MANIFEST_FILE_SCHEMA.as_struct()
+
POSITIONAL_DELETE_SCHEMA = Schema(
NestedField(2147483546, "file_path", StringType()),
NestedField(2147483545, "pos", IntegerType())
)
@@ -321,7 +325,7 @@ class ManifestFile(Record):
key_metadata: Optional[bytes]
def __init__(self, *data: Any, **named_data: Any) -> None:
- super().__init__(*data, **{"struct": MANIFEST_FILE_SCHEMA.as_struct(),
**named_data})
+ super().__init__(*data, **{"struct": MANIFEST_FILE_SCHEMA_STRUCT,
**named_data})
def has_added_files(self) -> bool:
return self.added_files_count is None or self.added_files_count > 0
diff --git a/python/tests/avro/test_decoder.py
b/python/tests/avro/test_decoder.py
index ee9a5b210e..b1ab97fb1d 100644
--- a/python/tests/avro/test_decoder.py
+++ b/python/tests/avro/test_decoder.py
@@ -26,73 +26,84 @@ from uuid import UUID
import pytest
-from pyiceberg.avro.decoder import BinaryDecoder
+from pyiceberg.avro.decoder import BinaryDecoder, InMemoryBinaryDecoder,
StreamingBinaryDecoder
from pyiceberg.avro.resolver import resolve
from pyiceberg.io import InputStream
from pyiceberg.types import DoubleType, FloatType
+AVAILABLE_DECODERS = [StreamingBinaryDecoder, InMemoryBinaryDecoder]
-def test_read_decimal_from_fixed() -> None:
+
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_decimal_from_fixed(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x00\x00\x00\x05\x6A\x48\x1C\xFB\x2C\x7C\x50\x00")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
actual = decoder.read_decimal_from_fixed(28, 15, 12)
expected = Decimal("99892.123400000000000")
assert actual == expected
-def test_read_boolean_true() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_boolean_true(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x01")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_boolean() is True
-def test_read_boolean_false() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_boolean_false(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x00")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_boolean() is False
-def test_skip_boolean() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_boolean(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x00")
- decoder = BinaryDecoder(mis)
- assert mis.tell() == 0
+ decoder = decoder_class(mis)
+ assert decoder.tell() == 0
decoder.skip_boolean()
- assert mis.tell() == 1
+ assert decoder.tell() == 1
-def test_read_int() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_int(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x18")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_int() == 12
-def test_skip_int() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_int(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x18")
- decoder = BinaryDecoder(mis)
- assert mis.tell() == 0
+ decoder = decoder_class(mis)
+ assert decoder.tell() == 0
decoder.skip_int()
- assert mis.tell() == 1
+ assert decoder.tell() == 1
-def test_read_decimal() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_decimal(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x18\x00\x00\x00\x05\x6A\x48\x1C\xFB\x2C\x7C\x50\x00")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
actual = decoder.read_decimal_from_bytes(28, 15)
expected = Decimal("99892.123400000000000")
assert actual == expected
-def test_decimal_from_fixed_big() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_decimal_from_fixed_big(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x0E\xC2\x02\xE9\x06\x16\x33\x49\x77\x67\xA8\x00")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
actual = decoder.read_decimal_from_fixed(28, 15, 12)
expected = Decimal("4567335489766.998340000000000")
assert actual == expected
-def test_read_negative_bytes() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_negative_bytes(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
with pytest.raises(ValueError) as exc_info:
decoder.read(-1)
@@ -130,92 +141,107 @@ class OneByteAtATimeInputStream(InputStream):
self.close()
-def test_read_single_byte_at_the_time() -> None:
- decoder = BinaryDecoder(OneByteAtATimeInputStream())
+# InMemoryBinaryDecoder doesn't work for a byte at a time reading
[email protected]("decoder_class", [StreamingBinaryDecoder])
+def test_read_single_byte_at_the_time(decoder_class: Type[BinaryDecoder]) ->
None:
+ decoder = decoder_class(OneByteAtATimeInputStream())
assert decoder.read(2) == b"\x01\x02"
-def test_read_float() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_float(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x00\x00\x9A\x41")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_float() == 19.25
-def test_skip_float() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_float(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x00\x00\x9A\x41")
- decoder = BinaryDecoder(mis)
- assert mis.tell() == 0
+ decoder = decoder_class(mis)
+ assert decoder.tell() == 0
decoder.skip_float()
- assert mis.tell() == 4
+ assert decoder.tell() == 4
-def test_read_double() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_double(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x00\x00\x00\x00\x00\x40\x33\x40")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_double() == 19.25
-def test_skip_double() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_double(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x00\x00\x00\x00\x00\x40\x33\x40")
- decoder = BinaryDecoder(mis)
- assert mis.tell() == 0
+ decoder = decoder_class(mis)
+ assert decoder.tell() == 0
decoder.skip_double()
- assert mis.tell() == 8
+ assert decoder.tell() == 8
-def test_read_uuid_from_fixed() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_uuid_from_fixed(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x12\x34\x56\x78" * 4)
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_uuid_from_fixed() ==
UUID("{12345678-1234-5678-1234-567812345678}")
-def test_read_time_millis() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_time_millis(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\xBC\x7D")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_time_millis().microsecond == 30000
-def test_read_time_micros() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_time_micros(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\xBC\x7D")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_time_micros().microsecond == 8030
-def test_read_timestamp_micros() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_timestamp_micros(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\xBC\x7D")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_timestamp_micros() == datetime(1970, 1, 1, 0, 0, 0,
8030)
-def test_read_timestamptz_micros() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_timestamptz_micros(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\xBC\x7D")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_timestamptz_micros() == datetime(1970, 1, 1, 0, 0, 0,
8030, tzinfo=timezone.utc)
-def test_read_bytes() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_bytes(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x08\x01\x02\x03\x04")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
actual = decoder.read_bytes()
assert actual == b"\x01\x02\x03\x04"
-def test_read_utf8() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_utf8(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x04\x76\x6F")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
assert decoder.read_utf8() == "vo"
-def test_skip_utf8() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_utf8(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x04\x76\x6F")
- decoder = BinaryDecoder(mis)
- assert mis.tell() == 0
+ decoder = decoder_class(mis)
+ assert decoder.tell() == 0
decoder.skip_utf8()
- assert mis.tell() == 3
+ assert decoder.tell() == 3
-def test_read_int_as_float() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_int_as_float(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x00\x00\x9A\x41")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
reader = resolve(FloatType(), DoubleType())
assert reader.read(decoder) == 19.25
diff --git a/python/tests/avro/test_reader.py b/python/tests/avro/test_reader.py
index 84a418e1db..416a11abb2 100644
--- a/python/tests/avro/test_reader.py
+++ b/python/tests/avro/test_reader.py
@@ -17,10 +17,11 @@
# pylint:disable=protected-access
import io
import json
+from typing import Type
import pytest
-from pyiceberg.avro.decoder import BinaryDecoder
+from pyiceberg.avro.decoder import BinaryDecoder, InMemoryBinaryDecoder,
StreamingBinaryDecoder
from pyiceberg.avro.file import AvroFile
from pyiceberg.avro.reader import (
BinaryReader,
@@ -63,6 +64,8 @@ from pyiceberg.types import (
UUIDType,
)
+AVAILABLE_DECODERS = [StreamingBinaryDecoder, InMemoryBinaryDecoder]
+
def test_read_header(generated_manifest_entry_file: str,
iceberg_manifest_entry_schema: Schema) -> None:
with AvroFile[ManifestEntry](
@@ -335,18 +338,19 @@ def test_uuid_reader() -> None:
assert construct_reader(UUIDType()) == UUIDReader()
-def test_read_struct() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_struct(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x18")
- decoder = BinaryDecoder(mis)
-
+ decoder = decoder_class(mis)
struct = StructType(NestedField(1, "id", IntegerType(), required=True))
result = StructReader(((0, IntegerReader()),), Record,
struct).read(decoder)
assert repr(result) == "Record[id=12]"
-def test_read_struct_lambda() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_struct_lambda(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x18")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
struct = StructType(NestedField(1, "id", IntegerType(), required=True))
# You can also pass in an arbitrary function that returns a struct
@@ -356,9 +360,10 @@ def test_read_struct_lambda() -> None:
assert repr(result) == "Record[id=12]"
-def test_read_not_struct_type() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_not_struct_type(decoder_class: Type[BinaryDecoder]) -> None:
mis = io.BytesIO(b"\x18")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
struct = StructType(NestedField(1, "id", IntegerType(), required=True))
with pytest.raises(ValueError) as exc_info:
@@ -367,9 +372,10 @@ def test_read_not_struct_type() -> None:
assert "Incompatible with StructProtocol: <class 'str'>" in
str(exc_info.value)
-def test_read_struct_exception_handling() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_struct_exception_handling(decoder_class: Type[BinaryDecoder]) ->
None:
mis = io.BytesIO(b"\x18")
- decoder = BinaryDecoder(mis)
+ decoder = decoder_class(mis)
def raise_err(struct: StructType) -> None:
raise TypeError("boom")