This is an automated email from the ASF dual-hosted git repository. honahx pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push: new 1016b197 Implement __getstate__ and __setstate__ so that FileIO instances can be pickled (#543) 1016b197 is described below commit 1016b197096472599c42228205f8a23f9ed28765 Author: Amogh Jahagirdar <am...@tabular.io> AuthorDate: Sun Apr 7 15:14:46 2024 -0600 Implement __getstate__ and __setstate__ so that FileIO instances can be pickled (#543) --- pyiceberg/io/fsspec.py | 12 ++++++++++++ pyiceberg/io/pyarrow.py | 12 ++++++++++++ tests/io/test_fsspec.py | 31 +++++++++++++++++++++++++++++++ tests/io/test_io.py | 20 ++++++++++++++++++++ 4 files changed, 75 insertions(+) diff --git a/pyiceberg/io/fsspec.py b/pyiceberg/io/fsspec.py index 957cac66..ee97829c 100644 --- a/pyiceberg/io/fsspec.py +++ b/pyiceberg/io/fsspec.py @@ -20,6 +20,7 @@ import errno import json import logging import os +from copy import copy from functools import lru_cache, partial from typing import ( Any, @@ -338,3 +339,14 @@ class FsspecFileIO(FileIO): if scheme not in self._scheme_to_fs: raise ValueError(f"No registered filesystem for scheme: {scheme}") return self._scheme_to_fs[scheme](self.properties) + + def __getstate__(self) -> Dict[str, Any]: + """Create a dictionary of the FsSpecFileIO fields used when pickling.""" + fileio_copy = copy(self.__dict__) + fileio_copy["get_fs"] = None + return fileio_copy + + def __setstate__(self, state: Dict[str, Any]) -> None: + """Deserialize the state into a FsSpecFileIO instance.""" + self.__dict__ = state + self.get_fs = lru_cache(self._get_fs) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 06d03e21..1848fba7 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -33,6 +33,7 @@ import os import re from abc import ABC, abstractmethod from concurrent.futures import Future +from copy import copy from dataclasses import dataclass from enum import Enum from functools import lru_cache, singledispatch @@ -456,6 +457,17 @@ class PyArrowFileIO(FileIO): raise PermissionError(f"Cannot delete file, access denied: {location}") from e raise # pragma: no cover - If some other kind of OSError, raise the raw error + def __getstate__(self) -> Dict[str, Any]: + """Create a dictionary of the PyArrowFileIO fields used when pickling.""" + fileio_copy = copy(self.__dict__) + fileio_copy["fs_by_scheme"] = None + return fileio_copy + + def __setstate__(self, state: Dict[str, Any]) -> None: + """Deserialize the state into a PyArrowFileIO instance.""" + self.__dict__ = state + self.fs_by_scheme = lru_cache(self._initialize_fs) + def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema: return visit(schema, _ConvertToArrowSchema(metadata)) diff --git a/tests/io/test_fsspec.py b/tests/io/test_fsspec.py index 9f044454..013fc5d8 100644 --- a/tests/io/test_fsspec.py +++ b/tests/io/test_fsspec.py @@ -16,6 +16,7 @@ # under the License. import os +import pickle import tempfile import uuid @@ -229,6 +230,11 @@ def test_writing_avro_file(generated_manifest_entry_file: str, fsspec_fileio: Fs fsspec_fileio.delete(f"s3://warehouse/{filename}") +@pytest.mark.s3 +def test_fsspec_pickle_round_trip_s3(fsspec_fileio: FsspecFileIO) -> None: + _test_fsspec_pickle_round_trip(fsspec_fileio, "s3://warehouse/foo.txt") + + @pytest.mark.adlfs def test_fsspec_new_input_file_adlfs(adlfs_fsspec_fileio: FsspecFileIO) -> None: """Test creating a new input file from an fsspec file-io""" @@ -410,6 +416,11 @@ def test_writing_avro_file_adlfs(generated_manifest_entry_file: str, adlfs_fsspe adlfs_fsspec_fileio.delete(f"abfss://tests/{filename}") +@pytest.mark.adlfs +def test_fsspec_pickle_round_trip_aldfs(adlfs_fsspec_fileio: FsspecFileIO) -> None: + _test_fsspec_pickle_round_trip(adlfs_fsspec_fileio, "abfss://tests/foo.txt") + + @pytest.mark.gcs def test_fsspec_new_input_file_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None: """Test creating a new input file from a fsspec file-io""" @@ -586,6 +597,26 @@ def test_writing_avro_file_gcs(generated_manifest_entry_file: str, fsspec_fileio fsspec_fileio_gcs.delete(f"gs://warehouse/{filename}") +@pytest.mark.gcs +def test_fsspec_pickle_roundtrip_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None: + _test_fsspec_pickle_round_trip(fsspec_fileio_gcs, "gs://warehouse/foo.txt") + + +def _test_fsspec_pickle_round_trip(fsspec_fileio: FsspecFileIO, location: str) -> None: + serialized_file_io = pickle.dumps(fsspec_fileio) + deserialized_file_io = pickle.loads(serialized_file_io) + output_file = deserialized_file_io.new_output(location) + with output_file.create() as f: + f.write(b"foo") + + input_file = deserialized_file_io.new_input(location) + with input_file.open() as f: + data = f.read() + assert data == b"foo" + assert len(input_file) == 3 + deserialized_file_io.delete(location) + + TEST_URI = "https://iceberg-test-signer" diff --git a/tests/io/test_io.py b/tests/io/test_io.py index d27274e1..b273288b 100644 --- a/tests/io/test_io.py +++ b/tests/io/test_io.py @@ -16,6 +16,7 @@ # under the License. import os +import pickle import tempfile import pytest @@ -71,6 +72,25 @@ def test_custom_local_output_file() -> None: assert len(output_file) == 3 +def test_pickled_pyarrow_round_trip() -> None: + with tempfile.TemporaryDirectory() as tmpdirname: + file_location = os.path.join(tmpdirname, "foo.txt") + file_io = PyArrowFileIO() + serialized_file_io = pickle.dumps(file_io) + deserialized_file_io = pickle.loads(serialized_file_io) + absolute_file_location = os.path.abspath(file_location) + output_file = deserialized_file_io.new_output(location=f"{absolute_file_location}") + with output_file.create() as f: + f.write(b"foo") + + input_file = deserialized_file_io.new_input(location=f"{absolute_file_location}") + f = input_file.open() + data = f.read() + assert data == b"foo" + assert len(input_file) == 3 + deserialized_file_io.delete(location=f"{absolute_file_location}") + + def test_custom_local_output_file_with_overwrite() -> None: """Test initializing an OutputFile implementation to overwrite a local file""" with tempfile.TemporaryDirectory() as tmpdirname: