This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new b5b916bd90 Python: Improve AVRO reading speed (#8084)
b5b916bd90 is described below

commit b5b916bd90d5fd60e736fade47874b52ca30ab39
Author: Rusty Conover <[email protected]>
AuthorDate: Tue Jul 25 02:33:55 2023 -0400

    Python: Improve AVRO reading speed (#8084)
    
    * Python: Improve Avro read performance
    
    Utilize __slots__ for various Avro classes.
    
    Memoize the key/value reader functions in MapReader.
    
    * Python: improve Avro parsing by caching struct types.
    
    Cache the struct types for ManifestEntry and ManifestFile schemas.
    
    Cache the _position_to_field_name lookup in Records if a struct type is passed.
    
    Improve StructReader to use __slots__ and avoid isinstance() checks every
    time the structure is read, by performing the check once in __init__().
    
    Improve StructReader to determine the arguments to create_struct() once in
    __init__() and reuse the result each time the structure is read, avoiding
    a try block on every read.
    
    Python: Improve Avro parsing speed
    
    Cache the hash value of the Struct types so that when new
    Record objects are created time is not spent recomputing
    the hash value.
    
    Change the implementation of the Record class to cache the
    mapping from field name to position.  Also utilize an array
    rather than a dictionary for these lookups.
    
    * fix some lints
    
    * fix some type casing
    
    * fix: adjust PR comments
    
    * fix: address PR feedback.
    
    * Add additional tests for InMemoryBinaryDecoder
    
    * Load entire Avro file into memory rather than streaming it
    
    * fix: disable seeking on input file
    
    * Update python/pyiceberg/avro/file.py
    
    ---------
    
    Co-authored-by: Fokko Driesprong <[email protected]>
---
 python/pyiceberg/avro/decoder.py  | 192 ++++++++++++++++++++++++++++++++------
 python/pyiceberg/avro/file.py     |  20 ++--
 python/pyiceberg/avro/reader.py   | 114 ++++++++++++++++------
 python/pyiceberg/manifest.py      |   8 +-
 python/tests/avro/test_decoder.py | 140 ++++++++++++++++-----------
 python/tests/avro/test_reader.py  |  26 ++++--
 6 files changed, 362 insertions(+), 138 deletions(-)

diff --git a/python/pyiceberg/avro/decoder.py b/python/pyiceberg/avro/decoder.py
index e680c00bcd..35a1651192 100644
--- a/python/pyiceberg/avro/decoder.py
+++ b/python/pyiceberg/avro/decoder.py
@@ -15,9 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 import decimal
+from abc import ABC, abstractmethod
 from datetime import datetime, time
 from io import SEEK_CUR
-from typing import List
+from typing import Dict, List
 from uuid import UUID
 
 from pyiceberg.avro import STRUCT_DOUBLE, STRUCT_FLOAT
@@ -26,39 +27,24 @@ from pyiceberg.utils.datetime import micros_to_time, micros_to_timestamp, micros
 from pyiceberg.utils.decimal import unscaled_to_decimal
 
 
-class BinaryDecoder:
+class BinaryDecoder(ABC):
     """Read leaf values."""
 
-    __slots__ = "_input_stream"
-    _input_stream: InputStream
-
+    @abstractmethod
     def __init__(self, input_stream: InputStream) -> None:
-        """Reader is a Python object on which we can call read, seek, and 
tell."""
-        self._input_stream = input_stream
+        """Create the decoder."""
+
+    @abstractmethod
+    def tell(self) -> int:
+        """Return the current position."""
 
+    @abstractmethod
     def read(self, n: int) -> bytes:
         """Read n bytes."""
-        if n < 0:
-            raise ValueError(f"Requested {n} bytes to read, expected positive 
integer.")
-        data: List[bytes] = []
-
-        n_remaining = n
-        while n_remaining > 0:
-            data_read = self._input_stream.read(n_remaining)
-            read_len = len(data_read)
-            if read_len == n:
-                # If we read everything, we return directly
-                # otherwise we'll continue to fetch the rest
-                return data_read
-            elif read_len <= 0:
-                raise EOFError(f"EOF: read {read_len} bytes")
-            data.append(data_read)
-            n_remaining -= read_len
-
-        return b"".join(data)
 
+    @abstractmethod
     def skip(self, n: int) -> None:
-        self._input_stream.seek(n, SEEK_CUR)
+        """Skip n bytes."""
 
     def read_boolean(self) -> bool:
         """Reads a value from the stream as a boolean.
@@ -69,7 +55,7 @@ class BinaryDecoder:
         return ord(self.read(1)) == 1
 
     def read_int(self) -> int:
-        """Reads a value from the stream as an integer.
+        """Reads an int/long value.
 
         int/long values are written using variable-length, zigzag coding.
         """
@@ -83,6 +69,25 @@ class BinaryDecoder:
         datum = (n >> 1) ^ -(n & 1)
         return datum
 
+    def read_ints(self, n: int, dest: List[int]) -> None:
+        """Reads a list of integers."""
+        for _ in range(n):
+            dest.append(self.read_int())
+
+    def read_int_int_dict(self, n: int, dest: Dict[int, int]) -> None:
+        """Reads a dictionary of integers for keys and values into a 
destination dictionary."""
+        for _ in range(n):
+            k = self.read_int()
+            v = self.read_int()
+            dest[k] = v
+
+    def read_int_bytes_dict(self, n: int, dest: Dict[int, bytes]) -> None:
+        """Reads a dictionary of integers for keys and bytes for values into a 
destination dictionary."""
+        for _ in range(n):
+            k = self.read_int()
+            v = self.read_bytes()
+            dest[k] = v
+
     def read_float(self) -> float:
         """Reads a value from the stream as a float.
 
@@ -191,3 +196,136 @@ class BinaryDecoder:
 
     def skip_utf8(self) -> None:
         self.skip_bytes()
+
+
+class StreamingBinaryDecoder(BinaryDecoder):
+    """Read leaf values."""
+
+    __slots__ = "_input_stream"
+    _input_stream: InputStream
+
+    def __init__(self, input_stream: InputStream) -> None:
+        """Reader is a Python object on which we can call read, seek, and 
tell."""
+        super().__init__(input_stream)
+        self._input_stream = input_stream
+
+    def tell(self) -> int:
+        """Return the current stream position."""
+        return self._input_stream.tell()
+
+    def read(self, n: int) -> bytes:
+        """Read n bytes."""
+        if n < 0:
+            raise ValueError(f"Requested {n} bytes to read, expected positive 
integer.")
+        data: List[bytes] = []
+
+        n_remaining = n
+        while n_remaining > 0:
+            data_read = self._input_stream.read(n_remaining)
+            read_len = len(data_read)
+            if read_len == n:
+                # If we read everything, we return directly
+                # otherwise we'll continue to fetch the rest
+                return data_read
+            elif read_len <= 0:
+                raise EOFError(f"EOF: read {read_len} bytes")
+            data.append(data_read)
+            n_remaining -= read_len
+
+        return b"".join(data)
+
+    def skip(self, n: int) -> None:
+        self._input_stream.seek(n, SEEK_CUR)
+
+
+class InMemoryBinaryDecoder(BinaryDecoder):
+    """Implement a BinaryDecoder that reads from an in-memory buffer.
+
+    This may be more efficient if the entire block is already in memory
+    as it does not need to interact with the I/O subsystem.
+    """
+
+    __slots__ = ["_contents", "_position", "_size"]
+    _contents: bytes
+    _position: int
+    _size: int
+
+    def __init__(self, input_stream: InputStream) -> None:
+        """Reader is a Python object on which we can call read, seek, and 
tell."""
+        super().__init__(input_stream)
+        self._contents = input_stream.read()
+        self._size = len(self._contents)
+        self._position = 0
+
+    def tell(self) -> int:
+        """Return the current stream position."""
+        return self._position
+
+    def read(self, n: int) -> bytes:
+        """Read n bytes."""
+        if n < 0:
+            raise ValueError(f"Requested {n} bytes to read, expected positive 
integer.")
+        if self._position + n > self._size:
+            raise EOFError(f"EOF: read {n} bytes")
+        r = self._contents[self._position : self._position + n]
+        self._position += n
+        return r
+
+    def skip(self, n: int) -> None:
+        self._position += n
+
+    def read_boolean(self) -> bool:
+        """Reads a value from the stream as a boolean.
+
+        A boolean is written as a single byte
+        whose value is either 0 (false) or 1 (true).
+        """
+        r = self._contents[self._position]
+        self._position += 1
+        return r != 0
+
+    def read_int(self) -> int:
+        """Reads a value from the stream as an integer.
+
+        int/long values are written using variable-length, zigzag coding.
+        """
+        if self._position == self._size:
+            raise EOFError("EOF: read 1 byte")
+        b = self._contents[self._position]
+        self._position += 1
+        n = b & 0x7F
+        shift = 7
+        while b & 0x80:
+            b = self._contents[self._position]
+            self._position += 1
+            n |= (b & 0x7F) << shift
+            shift += 7
+        return (n >> 1) ^ -(n & 1)
+
+    def read_int_bytes_dict(self, n: int, dest: Dict[int, bytes]) -> None:
+        """Reads a dictionary of integers for keys and bytes for values into a 
destination dict."""
+        for _ in range(n):
+            k = self.read_int()
+
+            byte_length = self.read_int()
+            if byte_length <= 0:
+                dest[k] = b""
+            else:
+                dest[k] = self._contents[self._position : self._position + 
byte_length]
+                self._position += byte_length
+
+    def read_bytes(self) -> bytes:
+        """Bytes are encoded as a long followed by that many bytes of data."""
+        num_bytes = self.read_int()
+        if num_bytes <= 0:
+            return b""
+        r = self._contents[self._position : self._position + num_bytes]
+        self._position += num_bytes
+        return r
+
+    def skip_int(self) -> None:
+        b = self._contents[self._position]
+        self._position += 1
+        while b & 0x80:
+            b = self._contents[self._position]
+            self._position += 1
diff --git a/python/pyiceberg/avro/file.py b/python/pyiceberg/avro/file.py
index 520eda9ce1..f780a8a30c 100644
--- a/python/pyiceberg/avro/file.py
+++ b/python/pyiceberg/avro/file.py
@@ -35,17 +35,12 @@ from typing import (
 )
 
 from pyiceberg.avro.codecs import KNOWN_CODECS, Codec
-from pyiceberg.avro.decoder import BinaryDecoder
+from pyiceberg.avro.decoder import BinaryDecoder, InMemoryBinaryDecoder
 from pyiceberg.avro.encoder import BinaryEncoder
 from pyiceberg.avro.reader import Reader
 from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve
 from pyiceberg.avro.writer import Writer
-from pyiceberg.io import (
-    InputFile,
-    InputStream,
-    OutputFile,
-    OutputStream,
-)
+from pyiceberg.io import InputFile, OutputFile, OutputStream
 from pyiceberg.schema import Schema
 from pyiceberg.typedef import EMPTY_DICT, Record, StructProtocol
 from pyiceberg.types import (
@@ -134,7 +129,6 @@ class AvroFile(Generic[D]):
         "read_schema",
         "read_types",
         "read_enums",
-        "input_stream",
         "header",
         "schema",
         "reader",
@@ -145,7 +139,6 @@ class AvroFile(Generic[D]):
     read_schema: Optional[Schema]
     read_types: Dict[int, Callable[..., StructProtocol]]
     read_enums: Dict[int, Callable[..., Enum]]
-    input_stream: InputStream
     header: AvroFileHeader
     schema: Schema
     reader: Reader
@@ -172,8 +165,8 @@ class AvroFile(Generic[D]):
         Returns:
             A generator returning the AvroStructs.
         """
-        self.input_stream = self.input_file.open(seekable=False)
-        self.decoder = BinaryDecoder(self.input_stream)
+        with self.input_file.open() as f:
+            self.decoder = InMemoryBinaryDecoder(io.BytesIO(f.read()))
         self.header = self._read_header()
         self.schema = self.header.get_schema()
         if not self.read_schema:
@@ -187,7 +180,6 @@ class AvroFile(Generic[D]):
         self, exctype: Optional[Type[BaseException]], excinst: 
Optional[BaseException], exctb: Optional[TracebackType]
     ) -> None:
         """Performs cleanup when exiting the scope of a 'with' statement."""
-        self.input_stream.close()
 
     def __iter__(self) -> AvroFile[D]:
         """Returns an iterator for the AvroFile class."""
@@ -206,7 +198,9 @@ class AvroFile(Generic[D]):
         if codec := self.header.compression_codec():
             block_bytes = codec.decompress(block_bytes)
 
-        self.block = Block(reader=self.reader, block_records=block_records, 
block_decoder=BinaryDecoder(io.BytesIO(block_bytes)))
+        self.block = Block(
+            reader=self.reader, block_records=block_records, 
block_decoder=InMemoryBinaryDecoder(io.BytesIO(block_bytes))
+        )
         return block_records
 
     def __next__(self) -> D:
diff --git a/python/pyiceberg/avro/reader.py b/python/pyiceberg/avro/reader.py
index 7d3d8fcec2..ba9456e892 100644
--- a/python/pyiceberg/avro/reader.py
+++ b/python/pyiceberg/avro/reader.py
@@ -267,10 +267,11 @@ class OptionReader(Reader):
 
 
 class StructReader(Reader):
-    __slots__ = ("field_readers", "create_struct", "struct")
+    __slots__ = ("field_readers", "create_struct", "struct", 
"_create_with_keyword", "_field_reader_functions", "_hash")
     field_readers: Tuple[Tuple[Optional[int], Reader], ...]
     create_struct: Callable[..., StructProtocol]
     struct: StructType
+    field_reader_functions = Tuple[Tuple[Optional[str], int, 
Optional[Callable[[BinaryDecoder], Any]]], ...]
 
     def __init__(
         self,
@@ -282,24 +283,38 @@ class StructReader(Reader):
         self.create_struct = create_struct
         self.struct = struct
 
-    def read(self, decoder: BinaryDecoder) -> StructProtocol:
         try:
             # Try initializing the struct, first with the struct keyword 
argument
-            struct = self.create_struct(struct=self.struct)
+            created_struct = self.create_struct(struct=self.struct)
+            self._create_with_keyword = True
         except TypeError as e:
             if "'struct' is an invalid keyword argument for" in str(e):
-                struct = self.create_struct()
+                created_struct = self.create_struct()
+                self._create_with_keyword = False
             else:
                 raise ValueError(f"Unable to initialize struct: 
{self.create_struct}") from e
 
-        if not isinstance(struct, StructProtocol):
+        if not isinstance(created_struct, StructProtocol):
             raise ValueError(f"Incompatible with StructProtocol: 
{self.create_struct}")
 
-        for pos, field in self.field_readers:
+        reading_callbacks: List[Tuple[Optional[int], Callable[[BinaryDecoder], 
Any]]] = []
+        for pos, field in field_readers:
+            if pos is not None:
+                reading_callbacks.append((pos, field.read))
+            else:
+                reading_callbacks.append((None, field.skip))
+
+        self._field_reader_functions = tuple(reading_callbacks)
+        self._hash = hash(self._field_reader_functions)
+
+    def read(self, decoder: BinaryDecoder) -> StructProtocol:
+        struct = self.create_struct(struct=self.struct) if 
self._create_with_keyword else self.create_struct()
+
+        for pos, field_reader in self._field_reader_functions:
             if pos is not None:
-                struct[pos] = field.read(decoder)  # later: pass reuse in here
+                struct[pos] = field_reader(decoder)  # later: pass reuse in 
here
             else:
-                field.skip(decoder)
+                field_reader(decoder)
 
         return struct
 
@@ -321,51 +336,88 @@ class StructReader(Reader):
 
     def __hash__(self) -> int:
         """Returns a hashed representation of the StructReader class."""
-        return hash(self.field_readers)
+        return self._hash
 
 
-@dataclass(frozen=True)
+@dataclass(frozen=False, init=False)
 class ListReader(Reader):
-    __slots__ = ("element",)
+    __slots__ = ("element", "_is_int_list", "_hash")
     element: Reader
 
+    def __init__(self, element: Reader) -> None:
+        super().__init__()
+        self.element = element
+        self._hash = hash(self.element)
+        self._is_int_list = isinstance(self.element, IntegerReader)
+
     def read(self, decoder: BinaryDecoder) -> List[Any]:
-        read_items = []
+        read_items: List[Any] = []
         block_count = decoder.read_int()
         while block_count != 0:
             if block_count < 0:
                 block_count = -block_count
                 _ = decoder.read_int()
-            for _ in range(block_count):
-                read_items.append(self.element.read(decoder))
+            if self._is_int_list:
+                decoder.read_ints(block_count, read_items)
+            else:
+                for _ in range(block_count):
+                    read_items.append(self.element.read(decoder))
             block_count = decoder.read_int()
         return read_items
 
     def skip(self, decoder: BinaryDecoder) -> None:
         _skip_map_array(decoder, lambda: self.element.skip(decoder))
 
+    def __hash__(self) -> int:
+        """Returns a hashed representation of the ListReader class."""
+        return self._hash
 
-@dataclass(frozen=True)
+
+@dataclass(frozen=False, init=False)
 class MapReader(Reader):
-    __slots__ = ("key", "value")
+    __slots__ = ("key", "value", "_is_int_int", "_is_int_bytes", 
"_key_reader", "_value_reader", "_hash")
     key: Reader
     value: Reader
 
+    def __init__(self, key: Reader, value: Reader) -> None:
+        super().__init__()
+        self.key = key
+        self.value = value
+        if isinstance(self.key, IntegerReader):
+            self._is_int_int = isinstance(self.value, IntegerReader)
+            self._is_int_bytes = isinstance(self.value, BinaryReader)
+        else:
+            self._is_int_int = False
+            self._is_int_bytes = False
+            self._key_reader = self.key.read
+            self._value_reader = self.value.read
+        self._hash = hash((self.key, self.value))
+
     def read(self, decoder: BinaryDecoder) -> Dict[Any, Any]:
-        read_items = {}
-        block_count = decoder.read_int()
-        key_reader = self.key.read
-        value_reader = self.value.read
+        read_items: dict[Any, Any] = {}
 
-        while block_count != 0:
-            if block_count < 0:
-                block_count = -block_count
-                # We ignore the block size for now
-                _ = decoder.read_int()
-            for _ in range(block_count):
-                key = key_reader(decoder)
-                read_items[key] = value_reader(decoder)
-            block_count = decoder.read_int()
+        block_count = decoder.read_int()
+        if self._is_int_int or self._is_int_bytes:
+            while block_count != 0:
+                if block_count < 0:
+                    block_count = -block_count
+                    # We ignore the block size for now
+                    _ = decoder.read_int()
+                if self._is_int_int:
+                    decoder.read_int_int_dict(block_count, read_items)
+                else:
+                    decoder.read_int_bytes_dict(block_count, read_items)
+                block_count = decoder.read_int()
+        else:
+            while block_count != 0:
+                if block_count < 0:
+                    block_count = -block_count
+                    # We ignore the block size for now
+                    _ = decoder.read_int()
+                for _ in range(block_count):
+                    key = self._key_reader(decoder)
+                    read_items[key] = self._value_reader(decoder)
+                block_count = decoder.read_int()
 
         return read_items
 
@@ -375,3 +427,7 @@ class MapReader(Reader):
             self.value.skip(decoder)
 
         _skip_map_array(decoder, skip)
+
+    def __hash__(self) -> int:
+        """Returns a hashed representation of the MapReader class."""
+        return self._hash
diff --git a/python/pyiceberg/manifest.py b/python/pyiceberg/manifest.py
index ce0e47fc24..879872f892 100644
--- a/python/pyiceberg/manifest.py
+++ b/python/pyiceberg/manifest.py
@@ -231,6 +231,8 @@ MANIFEST_ENTRY_SCHEMA = Schema(
     NestedField(2, "data_file", DATA_FILE_TYPE, required=True),
 )
 
+MANIFEST_ENTRY_SCHEMA_STRUCT = MANIFEST_ENTRY_SCHEMA.as_struct()
+
 
 class ManifestEntry(Record):
     __slots__ = ("status", "snapshot_id", "data_sequence_number", 
"file_sequence_number", "data_file")
@@ -241,7 +243,7 @@ class ManifestEntry(Record):
     data_file: DataFile
 
     def __init__(self, *data: Any, **named_data: Any) -> None:
-        super().__init__(*data, **{"struct": 
MANIFEST_ENTRY_SCHEMA.as_struct(), **named_data})
+        super().__init__(*data, **{"struct": MANIFEST_ENTRY_SCHEMA_STRUCT, 
**named_data})
 
 
 PARTITION_FIELD_SUMMARY_TYPE = StructType(
@@ -281,6 +283,8 @@ MANIFEST_FILE_SCHEMA: Schema = Schema(
     NestedField(519, "key_metadata", BinaryType(), required=False),
 )
 
+MANIFEST_FILE_SCHEMA_STRUCT = MANIFEST_FILE_SCHEMA.as_struct()
+
 POSITIONAL_DELETE_SCHEMA = Schema(
     NestedField(2147483546, "file_path", StringType()), 
NestedField(2147483545, "pos", IntegerType())
 )
@@ -321,7 +325,7 @@ class ManifestFile(Record):
     key_metadata: Optional[bytes]
 
     def __init__(self, *data: Any, **named_data: Any) -> None:
-        super().__init__(*data, **{"struct": MANIFEST_FILE_SCHEMA.as_struct(), 
**named_data})
+        super().__init__(*data, **{"struct": MANIFEST_FILE_SCHEMA_STRUCT, 
**named_data})
 
     def has_added_files(self) -> bool:
         return self.added_files_count is None or self.added_files_count > 0
diff --git a/python/tests/avro/test_decoder.py b/python/tests/avro/test_decoder.py
index ee9a5b210e..b1ab97fb1d 100644
--- a/python/tests/avro/test_decoder.py
+++ b/python/tests/avro/test_decoder.py
@@ -26,73 +26,84 @@ from uuid import UUID
 
 import pytest
 
-from pyiceberg.avro.decoder import BinaryDecoder
+from pyiceberg.avro.decoder import BinaryDecoder, InMemoryBinaryDecoder, 
StreamingBinaryDecoder
 from pyiceberg.avro.resolver import resolve
 from pyiceberg.io import InputStream
 from pyiceberg.types import DoubleType, FloatType
 
+AVAILABLE_DECODERS = [StreamingBinaryDecoder, InMemoryBinaryDecoder]
 
-def test_read_decimal_from_fixed() -> None:
+
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_decimal_from_fixed(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x00\x00\x00\x05\x6A\x48\x1C\xFB\x2C\x7C\x50\x00")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     actual = decoder.read_decimal_from_fixed(28, 15, 12)
     expected = Decimal("99892.123400000000000")
     assert actual == expected
 
 
-def test_read_boolean_true() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_boolean_true(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x01")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_boolean() is True
 
 
-def test_read_boolean_false() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_boolean_false(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x00")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_boolean() is False
 
 
-def test_skip_boolean() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_boolean(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x00")
-    decoder = BinaryDecoder(mis)
-    assert mis.tell() == 0
+    decoder = decoder_class(mis)
+    assert decoder.tell() == 0
     decoder.skip_boolean()
-    assert mis.tell() == 1
+    assert decoder.tell() == 1
 
 
-def test_read_int() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_int(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x18")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_int() == 12
 
 
-def test_skip_int() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_int(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x18")
-    decoder = BinaryDecoder(mis)
-    assert mis.tell() == 0
+    decoder = decoder_class(mis)
+    assert decoder.tell() == 0
     decoder.skip_int()
-    assert mis.tell() == 1
+    assert decoder.tell() == 1
 
 
-def test_read_decimal() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_decimal(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x18\x00\x00\x00\x05\x6A\x48\x1C\xFB\x2C\x7C\x50\x00")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     actual = decoder.read_decimal_from_bytes(28, 15)
     expected = Decimal("99892.123400000000000")
     assert actual == expected
 
 
-def test_decimal_from_fixed_big() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_decimal_from_fixed_big(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x0E\xC2\x02\xE9\x06\x16\x33\x49\x77\x67\xA8\x00")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     actual = decoder.read_decimal_from_fixed(28, 15, 12)
     expected = Decimal("4567335489766.998340000000000")
     assert actual == expected
 
 
-def test_read_negative_bytes() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_negative_bytes(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
 
     with pytest.raises(ValueError) as exc_info:
         decoder.read(-1)
@@ -130,92 +141,107 @@ class OneByteAtATimeInputStream(InputStream):
         self.close()
 
 
-def test_read_single_byte_at_the_time() -> None:
-    decoder = BinaryDecoder(OneByteAtATimeInputStream())
+# InMemoryBinaryDecoder doesn't work for a byte at a time reading
[email protected]("decoder_class", [StreamingBinaryDecoder])
+def test_read_single_byte_at_the_time(decoder_class: Type[BinaryDecoder]) -> 
None:
+    decoder = decoder_class(OneByteAtATimeInputStream())
     assert decoder.read(2) == b"\x01\x02"
 
 
-def test_read_float() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_float(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x00\x00\x9A\x41")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_float() == 19.25
 
 
-def test_skip_float() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_float(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x00\x00\x9A\x41")
-    decoder = BinaryDecoder(mis)
-    assert mis.tell() == 0
+    decoder = decoder_class(mis)
+    assert decoder.tell() == 0
     decoder.skip_float()
-    assert mis.tell() == 4
+    assert decoder.tell() == 4
 
 
-def test_read_double() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_double(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x00\x00\x00\x00\x00\x40\x33\x40")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_double() == 19.25
 
 
-def test_skip_double() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_double(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x00\x00\x00\x00\x00\x40\x33\x40")
-    decoder = BinaryDecoder(mis)
-    assert mis.tell() == 0
+    decoder = decoder_class(mis)
+    assert decoder.tell() == 0
     decoder.skip_double()
-    assert mis.tell() == 8
+    assert decoder.tell() == 8
 
 
-def test_read_uuid_from_fixed() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_uuid_from_fixed(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x12\x34\x56\x78" * 4)
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_uuid_from_fixed() == 
UUID("{12345678-1234-5678-1234-567812345678}")
 
 
-def test_read_time_millis() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_time_millis(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\xBC\x7D")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_time_millis().microsecond == 30000
 
 
-def test_read_time_micros() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_time_micros(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\xBC\x7D")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_time_micros().microsecond == 8030
 
 
-def test_read_timestamp_micros() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_timestamp_micros(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\xBC\x7D")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_timestamp_micros() == datetime(1970, 1, 1, 0, 0, 0, 
8030)
 
 
-def test_read_timestamptz_micros() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_timestamptz_micros(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\xBC\x7D")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_timestamptz_micros() == datetime(1970, 1, 1, 0, 0, 0, 
8030, tzinfo=timezone.utc)
 
 
-def test_read_bytes() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_bytes(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x08\x01\x02\x03\x04")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     actual = decoder.read_bytes()
     assert actual == b"\x01\x02\x03\x04"
 
 
-def test_read_utf8() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_utf8(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x04\x76\x6F")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     assert decoder.read_utf8() == "vo"
 
 
-def test_skip_utf8() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_skip_utf8(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x04\x76\x6F")
-    decoder = BinaryDecoder(mis)
-    assert mis.tell() == 0
+    decoder = decoder_class(mis)
+    assert decoder.tell() == 0
     decoder.skip_utf8()
-    assert mis.tell() == 3
+    assert decoder.tell() == 3
 
 
-def test_read_int_as_float() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_int_as_float(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x00\x00\x9A\x41")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
     reader = resolve(FloatType(), DoubleType())
     assert reader.read(decoder) == 19.25
diff --git a/python/tests/avro/test_reader.py b/python/tests/avro/test_reader.py
index 84a418e1db..416a11abb2 100644
--- a/python/tests/avro/test_reader.py
+++ b/python/tests/avro/test_reader.py
@@ -17,10 +17,11 @@
 # pylint:disable=protected-access
 import io
 import json
+from typing import Type
 
 import pytest
 
-from pyiceberg.avro.decoder import BinaryDecoder
+from pyiceberg.avro.decoder import BinaryDecoder, InMemoryBinaryDecoder, 
StreamingBinaryDecoder
 from pyiceberg.avro.file import AvroFile
 from pyiceberg.avro.reader import (
     BinaryReader,
@@ -63,6 +64,8 @@ from pyiceberg.types import (
     UUIDType,
 )
 
+AVAILABLE_DECODERS = [StreamingBinaryDecoder, InMemoryBinaryDecoder]
+
 
 def test_read_header(generated_manifest_entry_file: str, 
iceberg_manifest_entry_schema: Schema) -> None:
     with AvroFile[ManifestEntry](
@@ -335,18 +338,19 @@ def test_uuid_reader() -> None:
     assert construct_reader(UUIDType()) == UUIDReader()
 
 
-def test_read_struct() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_struct(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x18")
-    decoder = BinaryDecoder(mis)
-
+    decoder = decoder_class(mis)
     struct = StructType(NestedField(1, "id", IntegerType(), required=True))
     result = StructReader(((0, IntegerReader()),), Record, 
struct).read(decoder)
     assert repr(result) == "Record[id=12]"
 
 
-def test_read_struct_lambda() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_struct_lambda(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x18")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
 
     struct = StructType(NestedField(1, "id", IntegerType(), required=True))
     # You can also pass in an arbitrary function that returns a struct
@@ -356,9 +360,10 @@ def test_read_struct_lambda() -> None:
     assert repr(result) == "Record[id=12]"
 
 
-def test_read_not_struct_type() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_not_struct_type(decoder_class: Type[BinaryDecoder]) -> None:
     mis = io.BytesIO(b"\x18")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
 
     struct = StructType(NestedField(1, "id", IntegerType(), required=True))
     with pytest.raises(ValueError) as exc_info:
@@ -367,9 +372,10 @@ def test_read_not_struct_type() -> None:
     assert "Incompatible with StructProtocol: <class 'str'>" in 
str(exc_info.value)
 
 
-def test_read_struct_exception_handling() -> None:
[email protected]("decoder_class", AVAILABLE_DECODERS)
+def test_read_struct_exception_handling(decoder_class: Type[BinaryDecoder]) -> 
None:
     mis = io.BytesIO(b"\x18")
-    decoder = BinaryDecoder(mis)
+    decoder = decoder_class(mis)
 
     def raise_err(struct: StructType) -> None:
         raise TypeError("boom")

Reply via email to