This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 1016b197 Implement __getstate__ and __setstate__ so that FileIO
instances can be pickled (#543)
1016b197 is described below
commit 1016b197096472599c42228205f8a23f9ed28765
Author: Amogh Jahagirdar <[email protected]>
AuthorDate: Sun Apr 7 15:14:46 2024 -0600
Implement __getstate__ and __setstate__ so that FileIO instances can be
pickled (#543)
---
pyiceberg/io/fsspec.py | 12 ++++++++++++
pyiceberg/io/pyarrow.py | 12 ++++++++++++
tests/io/test_fsspec.py | 31 +++++++++++++++++++++++++++++++
tests/io/test_io.py | 20 ++++++++++++++++++++
4 files changed, 75 insertions(+)
diff --git a/pyiceberg/io/fsspec.py b/pyiceberg/io/fsspec.py
index 957cac66..ee97829c 100644
--- a/pyiceberg/io/fsspec.py
+++ b/pyiceberg/io/fsspec.py
@@ -20,6 +20,7 @@ import errno
import json
import logging
import os
+from copy import copy
from functools import lru_cache, partial
from typing import (
Any,
@@ -338,3 +339,14 @@ class FsspecFileIO(FileIO):
if scheme not in self._scheme_to_fs:
raise ValueError(f"No registered filesystem for scheme: {scheme}")
return self._scheme_to_fs[scheme](self.properties)
+
+ def __getstate__(self) -> Dict[str, Any]:
+ """Create a dictionary of the FsSpecFileIO fields used when pickling."""
+ fileio_copy = copy(self.__dict__)
+ fileio_copy["get_fs"] = None
+ return fileio_copy
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ """Deserialize the state into a FsSpecFileIO instance."""
+ self.__dict__ = state
+ self.get_fs = lru_cache(self._get_fs)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 06d03e21..1848fba7 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -33,6 +33,7 @@ import os
import re
from abc import ABC, abstractmethod
from concurrent.futures import Future
+from copy import copy
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache, singledispatch
@@ -456,6 +457,17 @@ class PyArrowFileIO(FileIO):
raise PermissionError(f"Cannot delete file, access denied: {location}") from e
raise # pragma: no cover - If some other kind of OSError, raise the raw error
+ def __getstate__(self) -> Dict[str, Any]:
+ """Create a dictionary of the PyArrowFileIO fields used when pickling."""
+ fileio_copy = copy(self.__dict__)
+ fileio_copy["fs_by_scheme"] = None
+ return fileio_copy
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ """Deserialize the state into a PyArrowFileIO instance."""
+ self.__dict__ = state
+ self.fs_by_scheme = lru_cache(self._initialize_fs)
+
def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
return visit(schema, _ConvertToArrowSchema(metadata))
diff --git a/tests/io/test_fsspec.py b/tests/io/test_fsspec.py
index 9f044454..013fc5d8 100644
--- a/tests/io/test_fsspec.py
+++ b/tests/io/test_fsspec.py
@@ -16,6 +16,7 @@
# under the License.
import os
+import pickle
import tempfile
import uuid
@@ -229,6 +230,11 @@ def test_writing_avro_file(generated_manifest_entry_file: str, fsspec_fileio: Fs
fsspec_fileio.delete(f"s3://warehouse/{filename}")
+@pytest.mark.s3
+def test_fsspec_pickle_round_trip_s3(fsspec_fileio: FsspecFileIO) -> None:
+ _test_fsspec_pickle_round_trip(fsspec_fileio, "s3://warehouse/foo.txt")
+
+
@pytest.mark.adlfs
def test_fsspec_new_input_file_adlfs(adlfs_fsspec_fileio: FsspecFileIO) -> None:
"""Test creating a new input file from an fsspec file-io"""
@@ -410,6 +416,11 @@ def test_writing_avro_file_adlfs(generated_manifest_entry_file: str, adlfs_fsspe
adlfs_fsspec_fileio.delete(f"abfss://tests/{filename}")
+@pytest.mark.adlfs
+def test_fsspec_pickle_round_trip_aldfs(adlfs_fsspec_fileio: FsspecFileIO) -> None:
+ _test_fsspec_pickle_round_trip(adlfs_fsspec_fileio, "abfss://tests/foo.txt")
+
+
@pytest.mark.gcs
def test_fsspec_new_input_file_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None:
"""Test creating a new input file from a fsspec file-io"""
@@ -586,6 +597,26 @@ def test_writing_avro_file_gcs(generated_manifest_entry_file: str, fsspec_fileio
fsspec_fileio_gcs.delete(f"gs://warehouse/{filename}")
+@pytest.mark.gcs
+def test_fsspec_pickle_roundtrip_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None:
+ _test_fsspec_pickle_round_trip(fsspec_fileio_gcs, "gs://warehouse/foo.txt")
+
+
+def _test_fsspec_pickle_round_trip(fsspec_fileio: FsspecFileIO, location: str) -> None:
+ serialized_file_io = pickle.dumps(fsspec_fileio)
+ deserialized_file_io = pickle.loads(serialized_file_io)
+ output_file = deserialized_file_io.new_output(location)
+ with output_file.create() as f:
+ f.write(b"foo")
+
+ input_file = deserialized_file_io.new_input(location)
+ with input_file.open() as f:
+ data = f.read()
+ assert data == b"foo"
+ assert len(input_file) == 3
+ deserialized_file_io.delete(location)
+
+
TEST_URI = "https://iceberg-test-signer"
diff --git a/tests/io/test_io.py b/tests/io/test_io.py
index d27274e1..b273288b 100644
--- a/tests/io/test_io.py
+++ b/tests/io/test_io.py
@@ -16,6 +16,7 @@
# under the License.
import os
+import pickle
import tempfile
import pytest
@@ -71,6 +72,25 @@ def test_custom_local_output_file() -> None:
assert len(output_file) == 3
+def test_pickled_pyarrow_round_trip() -> None:
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ file_location = os.path.join(tmpdirname, "foo.txt")
+ file_io = PyArrowFileIO()
+ serialized_file_io = pickle.dumps(file_io)
+ deserialized_file_io = pickle.loads(serialized_file_io)
+ absolute_file_location = os.path.abspath(file_location)
+ output_file = deserialized_file_io.new_output(location=f"{absolute_file_location}")
+ with output_file.create() as f:
+ f.write(b"foo")
+
+ input_file = deserialized_file_io.new_input(location=f"{absolute_file_location}")
+ f = input_file.open()
+ data = f.read()
+ assert data == b"foo"
+ assert len(input_file) == 3
+ deserialized_file_io.delete(location=f"{absolute_file_location}")
+
+
def test_custom_local_output_file_with_overwrite() -> None:
"""Test initializing an OutputFile implementation to overwrite a local file"""
with tempfile.TemporaryDirectory() as tmpdirname: